Commit 443c77b3 authored by Penghui Liao's avatar Penghui Liao

support DB.BeginTx of golang 1.8

Signed-off-by: 's avatarPenghui Liao <liaoishere@gmail.com>
parent 0dff7717
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build go1.8
// Package orm provide ORM for MySQL/PostgreSQL/sqlite // Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage // Simple Usage
// //
...@@ -52,6 +54,7 @@ ...@@ -52,6 +54,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
...@@ -458,11 +461,15 @@ func (o *orm) Using(name string) error { ...@@ -458,11 +461,15 @@ func (o *orm) Using(name string) error {
// begin transaction // begin transaction
func (o *orm) Begin() error { func (o *orm) Begin() error {
return o.BeginTx(context.Background(), nil)
}
func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error {
if o.isTx { if o.isTx {
return ErrTxHasBegan return ErrTxHasBegan
} }
var tx *sql.Tx var tx *sql.Tx
tx, err := o.db.(txer).Begin() tx, err := o.db.(txer).BeginTx(ctx, opts)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"io" "io"
...@@ -150,6 +151,13 @@ func (d *dbQueryLog) Begin() (*sql.Tx, error) { ...@@ -150,6 +151,13 @@ func (d *dbQueryLog) Begin() (*sql.Tx, error) {
return tx, err return tx, err
} }
func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
a := time.Now()
tx, err := d.db.(txer).BeginTx(ctx, opts)
debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err)
return tx, err
}
func (d *dbQueryLog) Commit() error { func (d *dbQueryLog) Commit() error {
a := time.Now() a := time.Now()
err := d.db.(txEnder).Commit() err := d.db.(txEnder).Commit()
......
...@@ -12,10 +12,13 @@ ...@@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build go1.8
package orm package orm
import ( import (
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
...@@ -452,9 +455,9 @@ func TestNullDataTypes(t *testing.T) { ...@@ -452,9 +455,9 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr))
throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr))
throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr))
throwFail(t, AssertIs((*d.TimePtr).Format(testTime), timePtr.Format(testTime))) throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime)))
throwFail(t, AssertIs((*d.DatePtr).Format(testDate), datePtr.Format(testDate))) throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate)))
throwFail(t, AssertIs((*d.DateTimePtr).Format(testDateTime), dateTimePtr.Format(testDateTime))) throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime)))
} }
func TestDataCustomTypes(t *testing.T) { func TestDataCustomTypes(t *testing.T) {
...@@ -1990,6 +1993,66 @@ func TestTransaction(t *testing.T) { ...@@ -1990,6 +1993,66 @@ func TestTransaction(t *testing.T) {
} }
func TestTransactionIsolationLevel(t *testing.T) {
// this test worked when database support transaction isolation level
if IsSqlite {
return
}
o1 := NewOrm()
o2 := NewOrm()
// start two transaction with isolation level repeatable read
err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
throwFail(t, err)
err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
throwFail(t, err)
// o1 insert tag
var tag Tag
tag.Name = "test-transaction"
id, err := o1.Insert(&tag)
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
// o2 query tag table, no result
num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 0))
// o1 commit
o1.Commit()
// o2 query tag table, still no result
num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 0))
// o2 commit and query tag table, get the result
o2.Commit()
num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestBeginTxWithContextCanceled(t *testing.T) {
o := NewOrm()
ctx, cancel := context.WithCancel(context.Background())
o.BeginTx(ctx, nil)
id, err := o.Insert(&Tag{Name: "test-context"})
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
// cancel the context before commit to make it error
cancel()
err = o.Commit()
throwFail(t, AssertIs(err, context.Canceled))
}
func TestReadOrCreate(t *testing.T) { func TestReadOrCreate(t *testing.T) {
u := &User{ u := &User{
UserName: "Kyle", UserName: "Kyle",
...@@ -2260,6 +2323,7 @@ func TestIgnoreCaseTag(t *testing.T) { ...@@ -2260,6 +2323,7 @@ func TestIgnoreCaseTag(t *testing.T) {
throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
} }
func TestInsertOrUpdate(t *testing.T) { func TestInsertOrUpdate(t *testing.T) {
RegisterModel(new(User)) RegisterModel(new(User))
user := User{UserName: "unique_username133", Status: 1, Password: "o"} user := User{UserName: "unique_username133", Status: 1, Password: "o"}
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"time" "time"
...@@ -106,6 +107,17 @@ type Ormer interface { ...@@ -106,6 +107,17 @@ type Ormer interface {
// ... // ...
// err = o.Rollback() // err = o.Rollback()
Begin() error Begin() error
// begin transaction with provided context and option
// the provided context is used until the transaction is committed or rolled back.
// if the context is canceled, the transaction will be rolled back.
// the provided TxOptions is optional and may be nil if defaults should be used.
// if a non-default isolation level is used that the driver doesn't support, an error will be returned.
// for example:
// o := NewOrm()
// err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
// ...
// err = o.Rollback()
BeginTx(ctx context.Context, opts *sql.TxOptions) error
// commit transaction // commit transaction
Commit() error Commit() error
// rollback transaction // rollback transaction
...@@ -401,6 +413,7 @@ type dbQuerier interface { ...@@ -401,6 +413,7 @@ type dbQuerier interface {
// transaction beginner // transaction beginner
type txer interface { type txer interface {
Begin() (*sql.Tx, error) Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
} }
// transaction ending // transaction ending
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment