diff --git a/errors.go b/errors.go
index add5dbe..7df308a 100644
--- a/errors.go
+++ b/errors.go
@@ -69,6 +69,11 @@ func (e NoTransitionError) Error() string {
 	return "no transition"
 }
 
+// Unwrap returns the wrapped error, allowing errors.Is and errors.As to see it.
+func (e NoTransitionError) Unwrap() error {
+	return e.Err
+}
+
 // CanceledError is returned by FSM.Event() when a callback have canceled a
 // transition.
 type CanceledError struct {
@@ -82,6 +87,11 @@ func (e CanceledError) Error() string {
 	return "transition canceled"
 }
 
+// Unwrap returns the wrapped error, allowing errors.Is and errors.As to see it.
+func (e CanceledError) Unwrap() error {
+	return e.Err
+}
+
 // AsyncError is returned by FSM.Event() when a callback have initiated an
 // asynchronous state transition.
 type AsyncError struct {
@@ -98,6 +108,11 @@ func (e AsyncError) Error() string {
 	return "async started"
 }
 
+// Unwrap returns the wrapped error, allowing errors.Is and errors.As to see it.
+func (e AsyncError) Unwrap() error {
+	return e.Err
+}
+
 // InternalError is returned by FSM.Event() and should never occur. It is a
 // probably because of a bug.
 type InternalError struct{}
diff --git a/errors_test.go b/errors_test.go
index ba384ee..31637d0 100644
--- a/errors_test.go
+++ b/errors_test.go
@@ -60,6 +60,12 @@ func TestNoTransitionError(t *testing.T) {
 	if e.Error() != "no transition with error: "+e.Err.Error() {
 		t.Error("NoTransitionError string mismatch")
 	}
+	if e.Unwrap() == nil {
+		t.Error("NoTransitionError Unwrap() should not be nil")
+	}
+	if !errors.Is(e, e.Err) {
+		t.Error("NoTransitionError should be equal to its error")
+	}
 }
 
 func TestCanceledError(t *testing.T) {
@@ -71,6 +77,12 @@ func TestCanceledError(t *testing.T) {
 	if e.Error() != "transition canceled with error: "+e.Err.Error() {
 		t.Error("CanceledError string mismatch")
 	}
+	if e.Unwrap() == nil {
+		t.Error("CanceledError Unwrap() should not be nil")
+	}
+	if !errors.Is(e, e.Err) {
+		t.Error("CanceledError should be equal to its error")
+	}
 }
 
 func TestAsyncError(t *testing.T) {
@@ -82,6 +94,12 @@ func TestAsyncError(t *testing.T) {
 	if e.Error() != "async started with error: "+e.Err.Error() {
 		t.Error("AsyncError string mismatch")
 	}
+	if e.Unwrap() == nil {
+		t.Error("AsyncError Unwrap() should not be nil")
+	}
+	if !errors.Is(e, e.Err) {
+		t.Error("AsyncError should be equal to its error")
+	}
 }
 
 func TestInternalError(t *testing.T) {
diff --git a/fsm.go b/fsm.go
index bd269d7..f3af5f2 100644
--- a/fsm.go
+++ b/fsm.go
@@ -228,6 +228,8 @@ func (f *FSM) SetState(state string) {
 
 // Can returns true if event can occur in the current state.
 func (f *FSM) Can(event string) bool {
+	f.eventMu.Lock()
+	defer f.eventMu.Unlock()
 	f.stateMu.RLock()
 	defer f.stateMu.RUnlock()
 	_, ok := f.transitions[eKey{event, f.current}]
@@ -333,6 +335,14 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro
 	}
 
 	if f.current == dst {
+		// Drop the state read lock and the event lock before the after-event
+		// callbacks run, so a callback may itself call Event. The deferred
+		// RLock re-acquires the read lock on return.
+		// NOTE(review): assumed to balance an RUnlock deferred earlier in Event.
+		f.stateMu.RUnlock()
+		defer f.stateMu.RLock()
+		f.eventMu.Unlock()
+		unlocked = true
 		f.afterEventCallbacks(ctx, e)
 		return NoTransitionError{e.Err}
 	}
diff --git a/fsm_test.go b/fsm_test.go
index 8e0c4ed..6a8dadf 100644
--- a/fsm_test.go
+++ b/fsm_test.go
@@ -825,6 +825,34 @@ func TestNoTransition(t *testing.T) {
 	}
 }
 
+// TestNoTransitionAfterEventCallbackTransition checks that an after_event
+// callback fired by a no-transition event ("run": start -> start) can trigger
+// another event, and that the nested "finish" transition takes effect.
+func TestNoTransitionAfterEventCallbackTransition(t *testing.T) {
+	var fsm *FSM
+	fsm = NewFSM(
+		"start",
+		Events{
+			{Name: "run", Src: []string{"start"}, Dst: "start"},
+			{Name: "finish", Src: []string{"start"}, Dst: "finished"},
+		},
+		Callbacks{
+			"after_event": func(_ context.Context, e *Event) {
+				fsm.Event(context.Background(), "finish")
+			},
+		},
+	)
+	err := fsm.Event(context.Background(), "run")
+	if _, ok := err.(NoTransitionError); !ok {
+		t.Error("expected 'NoTransitionError'")
+	}
+
+	currentState := fsm.Current()
+	if currentState != "finished" {
+		t.Errorf("expected state to be 'finished', was '%s'", currentState)
+	}
+}
+
 func ExampleNewFSM() {
 	fsm := NewFSM(
 		"green",
@@ -1013,3 +1041,34 @@ func ExampleFSM_Transition() {
 	// closed
 	// open
 }
+
+// TestEventAndCanInGoroutines calls Event and Can concurrently from several
+// goroutines. It makes no assertions; it only has to finish without
+// deadlocking, and is most useful under the race detector (go test -race).
+func TestEventAndCanInGoroutines(t *testing.T) {
+	fsm := NewFSM(
+		"closed",
+		Events{
+			{Name: "open", Src: []string{"closed"}, Dst: "open"},
+			{Name: "close", Src: []string{"open"}, Dst: "closed"},
+		},
+		Callbacks{},
+	)
+	wg := sync.WaitGroup{}
+	for i := 0; i < 10; i++ {
+		wg.Add(2)
+		go func(n int) {
+			defer wg.Done()
+			if n%2 == 0 {
+				_ = fsm.Event(context.Background(), "open")
+			} else {
+				_ = fsm.Event(context.Background(), "close")
+			}
+		}(i)
+		go func() {
+			defer wg.Done()
+			fsm.Can("close")
+		}()
+	}
+	wg.Wait()
+}