Commit 02c2e162 authored by astaxie's avatar astaxie

Strengthens the session's function

parent 59a67720
package session package session
import ( import (
"errors"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
...@@ -48,6 +50,14 @@ func (fs *FileSessionStore) Delete(key interface{}) error { ...@@ -48,6 +50,14 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
return nil return nil
} }
func (fs *FileSessionStore) Flush() error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values = make(map[interface{}]interface{})
fs.updatecontent()
return nil
}
func (fs *FileSessionStore) SessionID() string { func (fs *FileSessionStore) SessionID() string {
return fs.sid return fs.sid
} }
...@@ -121,6 +131,55 @@ func (fp *FileProvider) SessionGC() { ...@@ -121,6 +131,55 @@ func (fp *FileProvider) SessionGC() {
filepath.Walk(fp.savePath, gcpath) filepath.Walk(fp.savePath, gcpath)
} }
func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
if err != nil {
println(err.Error())
}
err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil {
println(err.Error())
}
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var newf *os.File
if err == nil {
return nil, errors.New("newsid exist")
} else if os.IsNotExist(err) {
newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
}
_, err = os.Stat(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid), os.O_RDWR, 0777)
io.Copy(newf, f)
} else if os.IsNotExist(err) {
newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
} else {
return nil, err
}
f.Close()
os.Remove(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])))
os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
var kv map[interface{}]interface{}
b, err := ioutil.ReadAll(newf)
if err != nil {
return nil, err
}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(b)
if err != nil {
return nil, err
}
}
newf, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_WRONLY|os.O_CREATE, 0777)
ss := &FileSessionStore{f: newf, sid: sid, values: kv}
return ss, nil
}
func gcpath(path string, info os.FileInfo, err error) error { func gcpath(path string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
return err return err
......
...@@ -40,6 +40,13 @@ func (st *MemSessionStore) Delete(key interface{}) error { ...@@ -40,6 +40,13 @@ func (st *MemSessionStore) Delete(key interface{}) error {
return nil return nil
} }
func (st *MemSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.value = make(map[interface{}]interface{})
return nil
}
func (st *MemSessionStore) SessionID() string { func (st *MemSessionStore) SessionID() string {
return st.sid return st.sid
} }
...@@ -80,6 +87,29 @@ func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) { ...@@ -80,6 +87,29 @@ func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
return nil, nil return nil, nil
} }
func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
pder.lock.RLock()
if element, ok := pder.sessions[oldsid]; ok {
go pder.SessionUpdate(oldsid)
pder.lock.RUnlock()
pder.lock.Lock()
element.Value.(*MemSessionStore).sid = sid
pder.sessions[sid] = element
delete(pder.sessions, oldsid)
pder.lock.Unlock()
return element.Value.(*MemSessionStore), nil
} else {
pder.lock.RUnlock()
pder.lock.Lock()
newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
element := pder.list.PushBack(newsess)
pder.sessions[sid] = element
pder.lock.Unlock()
return newsess, nil
}
return nil, nil
}
func (pder *MemProvider) SessionDestroy(sid string) error { func (pder *MemProvider) SessionDestroy(sid string) error {
pder.lock.Lock() pder.lock.Lock()
defer pder.lock.Unlock() defer pder.lock.Unlock()
......
...@@ -50,6 +50,14 @@ func (st *MysqlSessionStore) Delete(key interface{}) error { ...@@ -50,6 +50,14 @@ func (st *MysqlSessionStore) Delete(key interface{}) error {
return nil return nil
} }
func (st *MysqlSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
st.updatemysql()
return nil
}
func (st *MysqlSessionStore) SessionID() string { func (st *MysqlSessionStore) SessionID() string {
return st.sid return st.sid
} }
...@@ -108,6 +116,28 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { ...@@ -108,6 +116,28 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
return rs, nil return rs, nil
} }
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=?", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
}
c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &MysqlSessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
func (mp *MysqlProvider) SessionDestroy(sid string) error { func (mp *MysqlProvider) SessionDestroy(sid string) error {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=?", sid) c.Exec("DELETE FROM session where session_key=?", sid)
......
...@@ -35,6 +35,11 @@ func (rs *RedisSessionStore) Delete(key interface{}) error { ...@@ -35,6 +35,11 @@ func (rs *RedisSessionStore) Delete(key interface{}) error {
return err return err
} }
func (rs *RedisSessionStore) Flush() error {
_, err := rs.c.Do("DEL", rs.sid)
return err
}
func (rs *RedisSessionStore) SessionID() string { func (rs *RedisSessionStore) SessionID() string {
return rs.sid return rs.sid
} }
...@@ -99,6 +104,16 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) { ...@@ -99,6 +104,16 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
return rs, nil return rs, nil
} }
func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
c := rp.connectInit()
if str, err := redis.String(c.Do("HGET", oldsid, oldsid)); err != nil || str == "" {
c.Do("HSET", oldsid, oldsid, rp.maxlifetime)
}
c.Do("RENAME", oldsid, sid)
rs := &RedisSessionStore{c: c, sid: sid}
return rs, nil
}
func (rp *RedisProvider) SessionDestroy(sid string) error { func (rp *RedisProvider) SessionDestroy(sid string) error {
c := rp.connectInit() c := rp.connectInit()
c.Do("DEL", sid) c.Do("DEL", sid)
......
package session package session
import ( import (
"crypto/hmac"
"crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/sha1"
"encoding/base64" "encoding/base64"
"encoding/hex"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
...@@ -16,11 +20,13 @@ type SessionStore interface { ...@@ -16,11 +20,13 @@ type SessionStore interface {
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource SessionRelease() // release the resource
Flush() error //delete all data
} }
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(maxlifetime int64, savePath string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionRegenerate(oldsid, sid string) (SessionStore, error)
SessionDestroy(sid string) error SessionDestroy(sid string) error
SessionGC() SessionGC()
} }
...@@ -44,40 +50,91 @@ type Manager struct { ...@@ -44,40 +50,91 @@ type Manager struct {
cookieName string //private cookiename cookieName string //private cookiename
provider Provider provider Provider
maxlifetime int64 maxlifetime int64
hashfunc string //support md5 & sha1
hashkey string
options []interface{} options []interface{}
} }
//options
//1. is https default false
//2. hashfunc default sha1
//3. hashkey default beegosessionkey
//4. maxage default is none
func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) {
provider, ok := provides[provideName] provider, ok := provides[provideName]
if !ok { if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
} }
provider.SessionInit(maxlifetime, savePath) provider.SessionInit(maxlifetime, savePath)
return &Manager{provider: provider, cookieName: cookieName, maxlifetime: maxlifetime, options: options}, nil hashfunc := "sha1"
if len(options) > 1 {
hashfunc = options[1].(string)
}
hashkey := "beegosessionkey"
if len(options) > 2 {
hashkey = options[2].(string)
}
return &Manager{
provider: provider,
cookieName: cookieName,
maxlifetime: maxlifetime,
hashfunc: hashfunc,
hashkey: hashkey,
options: options,
}, nil
} }
//get Session //get Session
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.cookieName)
maxage := -1
if len(manager.options) > 3 {
switch manager.options[3].(type) {
case int:
if manager.options[3].(int) > 0 {
maxage = manager.options[3].(int)
} else if manager.options[3].(int) < 0 {
maxage = 0
}
case int64:
if manager.options[3].(int64) > 0 {
maxage = int(manager.options[3].(int64))
} else if manager.options[3].(int64) < 0 {
maxage = 0
}
case int32:
if manager.options[3].(int32) > 0 {
maxage = int(manager.options[3].(int32))
} else if manager.options[3].(int32) < 0 {
maxage = 0
}
}
}
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
sid := manager.sessionId() sid := manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
secure := false secure := false
if len(manager.options) > 0 { if len(manager.options) > 0 {
secure = manager.options[0].(bool) secure = manager.options[0].(bool)
} }
cookie := http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.cookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: secure} Secure: secure}
if maxage >= 0 {
cookie.MaxAge = maxage
}
//cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second) //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
http.SetCookie(w, &cookie) http.SetCookie(w, cookie)
r.AddCookie(&cookie) r.AddCookie(cookie)
} else { } else {
//cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second) //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
cookie.HttpOnly = true cookie.HttpOnly = true
cookie.Path = "/" cookie.Path = "/"
if maxage >= 0 {
cookie.MaxAge = maxage
}
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
sid, _ := url.QueryUnescape(cookie.Value) sid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
...@@ -103,10 +160,81 @@ func (manager *Manager) GC() { ...@@ -103,10 +160,81 @@ func (manager *Manager) GC() {
time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() })
} }
func (manager *Manager) sessionId() string { func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
sid := manager.sessionId(r)
cookie, err := r.Cookie(manager.cookieName)
if err != nil && cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(sid)
secure := false
if len(manager.options) > 0 {
secure = manager.options[0].(bool)
}
cookie = &http.Cookie{Name: manager.cookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: true,
Secure: secure,
}
} else {
oldsid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRegenerate(oldsid, sid)
cookie.Value = url.QueryEscape(sid)
cookie.HttpOnly = true
cookie.Path = "/"
}
maxage := -1
if len(manager.options) > 3 {
switch manager.options[3].(type) {
case int:
if manager.options[3].(int) > 0 {
maxage = manager.options[3].(int)
} else if manager.options[3].(int) < 0 {
maxage = 0
}
case int64:
if manager.options[3].(int64) > 0 {
maxage = int(manager.options[3].(int64))
} else if manager.options[3].(int64) < 0 {
maxage = 0
}
case int32:
if manager.options[3].(int32) > 0 {
maxage = int(manager.options[3].(int32))
} else if manager.options[3].(int32) < 0 {
maxage = 0
}
}
}
if maxage >= 0 {
cookie.MaxAge = maxage
}
http.SetCookie(w, cookie)
r.AddCookie(cookie)
return
}
//remote_addr cruunixnano randdata
func (manager *Manager) sessionId(r *http.Request) (sid string) {
b := make([]byte, 24) b := make([]byte, 24)
if _, err := io.ReadFull(rand.Reader, b); err != nil { if _, err := io.ReadFull(rand.Reader, b); err != nil {
return "" return ""
} }
return base64.URLEncoding.EncodeToString(b) bs := base64.URLEncoding.EncodeToString(b)
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
if manager.hashfunc == "md5" {
h := md5.New()
h.Write([]byte(bs))
sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
} else if manager.hashfunc == "sha1" {
h := hmac.New(sha1.New, []byte(manager.hashkey))
fmt.Fprintf(h, "%s", sig)
sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
} else {
h := hmac.New(sha1.New, []byte(manager.hashkey))
fmt.Fprintf(h, "%s", sig)
sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
}
return
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment