diff --git a/RequestBuilder.cs b/RequestBuilder.cs
index 56226fd72..253493383 100644
--- a/RequestBuilder.cs
+++ b/RequestBuilder.cs
@@ -10,6 +10,7 @@
 using Newtonsoft.Json;
 using System.IO;
 using System.Web;
+using System.Threading;
 
 namespace Refit
 {
@@ -84,6 +85,212 @@ public Func BuildRequestFactoryForMethod(string me
                 return ret;
             };
         }
+
+        /// <summary>
+        /// Builds the function that performs the REST call for <paramref name="methodName"/>.
+        /// The returned delegate takes the HttpClient and the interface method's
+        /// arguments and returns a Task, Task&lt;object&gt; or IObservable&lt;object&gt;
+        /// depending on the method's declared return type.
+        /// </summary>
+        /// <exception cref="ArgumentException">The method is not defined or has no HTTP Method attribute.</exception>
+        public Func<HttpClient, object[], object> BuildRestResultFuncForMethod(string methodName)
+        {
+            RestMethodInfo restMethod;
+            if (!interfaceHttpMethods.TryGetValue(methodName, out restMethod)) {
+                throw new ArgumentException("Method must be defined and have an HTTP Method attribute");
+            }
+
+            // RestMethodInfo only admits Task, Task<T> and IObservable<T>, so a
+            // ReturnType that is not a Task is always an IObservable<T>.
+            if (restMethod.ReturnType == typeof(Task)) {
+                return buildVoidTaskFuncForMethod(restMethod);
+            } else if (restMethod.ReturnType.GetGenericTypeDefinition() == typeof(Task<>)) {
+                return buildTaskFuncForMethod(restMethod);
+            } else {
+                return buildRxFuncForMethod(restMethod);
+            }
+        }
+
+        /// <summary>
+        /// Builds the REST call for a method returning a plain Task: the response
+        /// body is ignored and a non-success status code faults the Task.
+        /// </summary>
+        Func<HttpClient, object[], Task> buildVoidTaskFuncForMethod(RestMethodInfo restMethod)
+        {
+            var factory = BuildRequestFactoryForMethod(restMethod.Name);
+
+            return async (client, paramList) => {
+                var rq = factory(paramList);
+                var resp = await client.SendAsync(rq);
+
+                resp.EnsureSuccessStatusCode();
+            };
+        }
+
+        /// <summary>
+        /// Builds the REST call whose Task yields: the raw HttpResponseMessage (no
+        /// status check) when the method's result type is HttpResponseMessage; the
+        /// response body as-is when it is string; otherwise the body deserialized
+        /// from JSON. A non-success status faults the Task in the latter two cases.
+        /// </summary>
+        Func<HttpClient, object[], Task<object>> buildTaskFuncForMethod(RestMethodInfo restMethod)
+        {
+            var factory = BuildRequestFactoryForMethod(restMethod.Name);
+
+            return async (client, paramList) => {
+                var rq = factory(paramList);
+                var resp = await client.SendAsync(rq);
+                if (restMethod.SerializedReturnType == null) {
+                    return resp;
+                }
+
+                resp.EnsureSuccessStatusCode();
+
+                var content = await resp.Content.ReadAsStringAsync();
+                if (restMethod.SerializedReturnType == typeof(string)) {
+                    return content;
+                }
+
+                return JsonConvert.DeserializeObject(content, restMethod.SerializedReturnType);
+            };
+        }
+
+        /// <summary>
+        /// Builds the REST call for a method returning IObservable&lt;T&gt;. The request
+        /// is sent as soon as the returned delegate is invoked; its single result (or
+        /// failure) is replayed to every subscriber, including late ones.
+        /// </summary>
+        Func<HttpClient, object[], IObservable<object>> buildRxFuncForMethod(RestMethodInfo restMethod)
+        {
+            var taskFunc = buildTaskFuncForMethod(restMethod);
+
+            return (client, paramList) => {
+                var ret = new FakeAsyncSubject<object>();
+
+                taskFunc(client, paramList).ContinueWith(t => {
+                    if (t.IsFaulted) {
+                        // Report the underlying failure rather than the AggregateException
+                        // wrapper the Task adds around it.
+                        var ex = t.Exception.Flatten();
+                        ret.OnError(ex.InnerExceptions.Count == 1 ? ex.InnerExceptions[0] : ex);
+                    } else if (t.IsCanceled) {
+                        // A canceled Task has no Exception, and reading Result would throw.
+                        ret.OnError(new TaskCanceledException(t));
+                    } else {
+                        ret.OnNext(t.Result);
+                        ret.OnCompleted();
+                    }
+                });
+
+                return ret;
+            };
+        }
+
+        /// <summary>
+        /// Minimal AsyncSubject: keeps the last value passed to OnNext and, when the
+        /// sequence terminates, delivers that value plus the terminal notification
+        /// to every current subscriber and replays them to later subscribers.
+        /// </summary>
+        class FakeAsyncSubject<T> : IObservable<T>, IObserver<T>
+        {
+            // Guards all of the mutable state below.
+            readonly object gate = new object();
+            readonly List<IObserver<T>> subscriberList = new List<IObserver<T>>();
+            bool resultSet;
+            T result;
+            bool isStopped;
+            Exception stopError;
+
+            public void OnNext(T value)
+            {
+                lock (gate) {
+                    // Values that arrive after termination are dropped.
+                    if (isStopped) return;
+
+                    result = value;
+                    resultSet = true;
+                }
+            }
+
+            public void OnError(Exception error)
+            {
+                if (error == null) throw new ArgumentNullException("error");
+                terminate(error);
+            }
+
+            public void OnCompleted()
+            {
+                terminate(null);
+            }
+
+            public IDisposable Subscribe(IObserver<T> observer)
+            {
+                if (observer == null) throw new ArgumentNullException("observer");
+
+                lock (gate) {
+                    if (!isStopped) {
+                        subscriberList.Add(observer);
+                        return new AnonymousDisposable(() => {
+                            lock (gate) { subscriberList.Remove(observer); }
+                        });
+                    }
+                }
+
+                // Already terminated; the outcome can no longer change, so replay it
+                // outside the lock.
+                deliverOutcome(observer);
+                return new AnonymousDisposable(() => {});
+            }
+
+            // Records the first terminal notification (null error = completion) and
+            // delivers the outcome to the subscribers registered at that moment.
+            void terminate(Exception error)
+            {
+                IObserver<T>[] currentList;
+                lock (gate) {
+                    if (isStopped) return;
+
+                    isStopped = true;
+                    stopError = error;
+                    currentList = subscriberList.ToArray();
+                    subscriberList.Clear();
+                }
+
+                // Observers are invoked outside the lock so they may re-enter this subject.
+                foreach (var v in currentList) deliverOutcome(v);
+            }
+
+            // Sends OnError, or the stored value (if any) followed by OnCompleted.
+            // Only called once isStopped is set, after which these fields never change.
+            void deliverOutcome(IObserver<T> observer)
+            {
+                if (stopError != null) {
+                    observer.OnError(stopError);
+                    return;
+                }
+
+                if (resultSet) observer.OnNext(result);
+                observer.OnCompleted();
+            }
+        }
+    }
+
+    /// <summary>
+    /// IDisposable that runs the supplied action each time Dispose is called.
+    /// </summary>
+    sealed class AnonymousDisposable : IDisposable
+    {
+        readonly Action block;
+
+        public AnonymousDisposable(Action block)
+        {
+            this.block = block;
+        }
+
+        public void Dispose()
+        {
+            block();
+        }
     }
 
     class RestMethodInfo
@@ -96,6 +303,10 @@ class RestMethodInfo
         public Dictionary<int, string> ParameterMap { get; set; }
         public Tuple<BodySerializationMethod, int> BodyParameterInfo { get; set; }
         public Dictionary<int, string> QueryParameterMap { get; set; }
+        /// <summary>The method's declared return type (Task, Task&lt;T&gt; or IObservable&lt;T&gt;).</summary>
+        public Type ReturnType { get; set; }
+        /// <summary>The type the response body is deserialized into; null for a plain Task or a raw HttpResponseMessage.</summary>
+        public Type SerializedReturnType { get; set; }
 
         static readonly Regex parameterRegex = new Regex(@"^{(.*)}$");
 
@@ -113,6 +324,7 @@ public RestMethodInfo(Type targetInterface, MethodInfo methodInfo)
             RelativePath = hma.Path;
 
             verifyUrlPathIsSane(RelativePath);
+            determineReturnTypeInfo(methodInfo);
 
             var parameterList = methodInfo.GetParameters().ToList();
 
@@ -196,6 +408,38 @@ Tuple findBodyParameter(List parame
 
             return Tuple.Create(ret.BodyAttribute.SerializationMethod, parameterList.IndexOf(ret.Parameter));
         }
+
+        /// <summary>
+        /// Validates the method's return type and records it in ReturnType. Only
+        /// Task, Task&lt;T&gt; and IObservable&lt;T&gt; are accepted. SerializedReturnType
+        /// receives T, or null when there is nothing to deserialize (plain Task)
+        /// or the raw HttpResponseMessage was requested.
+        /// </summary>
+        /// <exception cref="ArgumentException">The method returns any other type.</exception>
+        void determineReturnTypeInfo(MethodInfo methodInfo)
+        {
+            var returnType = methodInfo.ReturnType;
+
+            if (returnType == typeof(Task)) {
+                // Non-generic Task carries no result type, and calling
+                // GetGenericTypeDefinition on it would throw.
+                ReturnType = returnType;
+                SerializedReturnType = null;
+                return;
+            }
+
+            var isSupported = returnType.IsGenericType &&
+                (returnType.GetGenericTypeDefinition() == typeof(Task<>) ||
+                 returnType.GetGenericTypeDefinition() == typeof(IObservable<>));
+
+            if (!isSupported) {
+                throw new ArgumentException("All REST Methods must return either Task or IObservable");
+            }
+
+            ReturnType = returnType;
+            SerializedReturnType = returnType.GetGenericArguments()[0];
+            if (SerializedReturnType == typeof(HttpResponseMessage)) SerializedReturnType = null;
+        }
     }
 
     /*
@@ -223,7 +467,13 @@ public interface IRestMethodInfoTests
         Task<string> FetchSomeStuffWithAlias([AliasAs("id")] int anId);
 
         [Get("/foo/bar/{id}")]
-        Task<string> FetchSomeStuffWithBody([AliasAs("id")] int anId, [Body] Dictionary<int, string> theData);
+        IObservable<string> FetchSomeStuffWithBody([AliasAs("id")] int anId, [Body] Dictionary<int, string> theData);
+
+        [Post("/foo/{id}")]
+        string AsyncOnlyBuddy(int id);
+
+        [Post("/foo/{id}")]
+        Task VoidTaskMethod(int id);
     }
 
     [TestFixture]
@@ -310,6 +560,25 @@ public void FindTheBodyParameter()
             Assert.AreEqual(0, fixture.QueryParameterMap.Count);
             Assert.AreEqual(1, fixture.BodyParameterInfo.Item2);
         }
+
+        [Test]
+        public void SyncMethodsShouldThrow()
+        {
+            var input = typeof(IRestMethodInfoTests);
+            var method = input.GetMethods().First(x => x.Name == "AsyncOnlyBuddy");
+
+            Assert.Throws<ArgumentException>(() => new RestMethodInfo(input, method));
+        }
+
+        [Test]
+        public void VoidTaskMethodsShouldBeAccepted()
+        {
+            var input = typeof(IRestMethodInfoTests);
+            var fixture = new RestMethodInfo(input, input.GetMethods().First(x => x.Name == "VoidTaskMethod"));
+
+            Assert.AreEqual(typeof(Task), fixture.ReturnType);
+            Assert.IsNull(fixture.SerializedReturnType);
+        }
     }
 
     public interface IDummyHttpApi