diff --git a/.skills/csharp-async-best-practices/SKILL.md b/.skills/csharp-async-best-practices/SKILL.md new file mode 100644 index 0000000000..fd4f6fa159 --- /dev/null +++ b/.skills/csharp-async-best-practices/SKILL.md @@ -0,0 +1,120 @@ +--- +name: csharp-async-best-practices +description: Use when reviewing, writing, refactoring, or designing c# async code that uses task, task-generic, valuetask, cancellationtoken, task.whenall, task.whenany, task.run, configureawait, async void, or fire-and-forget patterns. Trigger on `.result`, `.wait()`, deadlocks, cancellation propagation, asp.net core background work, ui responsiveness, exception flow, and performance-sensitive async api design. +metadata: + category: technique + triggers: + - c# + - async + - task + - valuetask + - cancellationtoken + - configureawait + - .result + - .wait() + - async void + - fire-and-forget + - task.run + - whenall + - whenany + - asp.net core + - deadlock +--- + +# C# Async Best Practices + +## Overview + +Apply evidence-backed async guidance with this priority order: + +1. correctness and cancellation semantics +2. context-specific API design +3. concurrency behavior and failure handling +4. performance tuning only when the hot path is real + +Treat blanket advice as suspect. Separate official behavior from expert interpretation and from your own recommendation. + +## Workflow + +1. Classify the code before judging it. + - **I/O-bound async**: network, file, database, timers, async waits + - **CPU-bound work**: expensive computation + - **Context**: library, UI app, ASP.NET Core app, background service, test code + - **Pressure**: hot path or ordinary path +2. Prefer the least surprising correct design. +3. Only optimize allocations or scheduling after the correctness story is sound. +4. Load the matching reference file before making strong claims. + - **General rules and code review defaults**: `references/core-guidance.md` + - **Context-sensitive rules**: `references/context-and-tradeoffs.md` + - **Source notes and authority breakdown**: `references/source-notes.md` + +## Review defaults + +Start from these defaults unless the case-specific evidence says otherwise: + +| Topic | Default judgment | +|---|---| +| Blocking on async | usually a defect or interop boundary smell | +| `async void` | only acceptable for event handlers | +| `ValueTask` | avoid by default; justify with measurements or a very hot path | +| `ConfigureAwait(false)` | good library default, not an app-wide default | +| `Task.Run` | use to offload CPU work when needed, not to fake async I/O | +| Fire-and-forget | assume unsafe until lifecycle, scope, and exception handling are explicit | +| `Task.WhenAll` | prefer for independent concurrent operations | +| `Task.WhenAny` | always inspect winner and define what happens to losers | +| Cancellation | accept and propagate token until the point of no cancellation | + +## Output contract + +When you review or design code, label your reasoning like this: + +- **Fact**: official runtime or API behavior +- **Expert guidance**: interpretation from strong experts when it adds design meaning +- **Synthesis**: your recommendation for this exact case + +Do not present contextual advice as a universal law. + +## Common traps + +- Calling `.Result`, `.Wait()`, or `GetAwaiter().GetResult()` inside normal async-capable code +- Recommending `ConfigureAwait(false)` everywhere because “it is .NET Core” or “it prevents deadlocks” +- Recommending `Task.Run` inside ASP.NET Core request code just to make code “more async” +- Recommending `ValueTask` for every hot-looking method without checking completion behavior, call frequency, or single-consumer assumptions +- Ignoring cancellation after plumbing a `CancellationToken` +- Using `Task.WhenAny` without awaiting the returned winner task or handling the remaining tasks +- Treating fire-and-forget as harmless when it touches scoped services, `HttpContext`, or unobserved failures + +## Rationalization traps + +| Rationalization | Better reasoning | +|---|---| +| “It works, so `.Result` is fine.” | Lack of failure under one context does not make blocking safe or scalable. | +| “`ValueTask` is always faster.” | It trades simplicity for niche allocation wins and stricter consumption rules. | +| “`ConfigureAwait(false)` everywhere is modern guidance.” | Library and app code have different constraints. Blanket rules are weak. | +| “`Task.Run` makes server code asynchronous.” | It only queues work; it does not turn blocking I/O into true async I/O. | +| “Fire-and-forget is okay because logging exists.” | Logging does not solve scope lifetime, shutdown, retries, or error propagation. | + +## Deliverable shape + +For code review or implementation help, prefer: + +1. a short context classification +2. the concrete problem +3. the corrected pattern +4. the context-dependent tradeoff, if any +5. the smallest safe code change + +## API shape and testability + +- Prefer `Async` suffixes for awaitable-returning methods unless an established contract or event pattern dictates otherwise. +- Prefer `Task`-returning seams over hidden background work so tests can await completion, faults, and cancellation. +- For timers, queues, retries, or background pipelines, recommend abstractions that let tests control time and observe completion. +- When reviewing an async API, ask whether callers can compose it, cancel it, await it, and assert its failure behavior. + +## Hard boundaries + +- Do not endorse sync-over-async as a normal design choice. +- Do not suggest `async void` except for event handlers. +- Do not suggest `ValueTask` unless the constraints are understood. +- Do not claim `ConfigureAwait(false)` is always needed or always unnecessary. +- Do not approve fire-and-forget unless ownership, exception handling, and lifetime are explicit. diff --git a/.skills/csharp-async-best-practices/references/context-and-tradeoffs.md b/.skills/csharp-async-best-practices/references/context-and-tradeoffs.md new file mode 100644 index 0000000000..f6bdaedb79 --- /dev/null +++ b/.skills/csharp-async-best-practices/references/context-and-tradeoffs.md @@ -0,0 +1,82 @@ +--- +description: >- + Context-specific async guidance for library code, ui apps, asp.net core, + background work, task.run, configureawait, and performance-sensitive design. +metadata: + tags: [configureawait, task.run, asp.net core, ui, library, performance] + source: mixed +--- + +# Context and Tradeoffs + +## Library code versus app code + +### General-purpose library code +- Prefer APIs that expose true async for I/O-bound work. +- Do not add async wrappers around purely compute-bound methods just to look modern. Expose sync compute APIs and let callers decide whether to offload. +- `ConfigureAwait(false)` is a strong default when the library does not need the caller’s context. +- Avoid ambient assumptions about a UI thread, request context, or test framework behavior. + +### App code +- Prefer the style that fits the app model. +- UI code often needs the original context after `await`. +- ASP.NET Core request code normally does not need `Task.Run` just to stay responsive, because it already runs on thread pool threads. +- Do not present “ASP.NET Core has no synchronization context” as proof that every `ConfigureAwait(false)` discussion is obsolete. + +## `Task.Run` boundaries + +### Good uses +- Offload CPU-bound work so a UI thread can stay responsive. +- Offload CPU work from a caller when that scheduling boundary is deliberate. + +### Weak uses +- Wrapping synchronous I/O to pretend it is true async I/O. +- Calling `Task.Run` and immediately awaiting it in ASP.NET Core request handling when no CPU offload goal exists. +- Using `Task.Run` to hide blocking APIs instead of fixing the underlying API choice. + +## Fire-and-forget + +### Assume unsafe until proven otherwise +A background task needs answers for all of these: +- Who owns its lifetime? +- How are exceptions observed? +- How does shutdown cancel it? +- Does it touch scoped services or request-bound objects? +- Does work need retries, backpressure, or queueing? + +### Safer alternatives +- Await the task normally. +- Queue work to an owned background component. +- In ASP.NET Core, prefer hosted services or a dedicated background queue pattern for long-lived work. +- If scoped services are required in background processing, create an explicit scope instead of capturing request scope objects. + +## `ConfigureAwait` + +### Strong recommendation +- In general-purpose libraries, use `ConfigureAwait(false)` unless the continuation must run in the captured context. + +### Weak recommendation +- “Always use it in app code.” +- “Never use it on .NET Core.” +- “Use it once at the first await and you are done.” + +### Review note +If code after the `await` needs a specific context, say so explicitly. If it does not, the recommendation depends on whether the code is app-level or general-purpose library code. + +## Performance guidance + +### Correctness first +Do not trade API clarity for speculative micro-optimizations. + +### `ValueTask` is performance-specialized +Recommend it only when most of these are true: +1. the method is called very frequently +2. it often completes synchronously or from a reusable source +3. allocation reduction matters on measurements +4. consumers can respect single-consumer semantics +5. task combinator ergonomics are not central to the API + +### Throttling and concurrency control +- `Task.WhenAll` expresses concurrency; it does not limit it. +- For bounded concurrency, use an async gate such as `SemaphoreSlim.WaitAsync`, or platform helpers such as `Parallel.ForEachAsync` when the workload fits. +- Always define what happens to remaining work after the first completion or first failure. diff --git a/.skills/csharp-async-best-practices/references/core-guidance.md b/.skills/csharp-async-best-practices/references/core-guidance.md new file mode 100644 index 0000000000..114ad50d0e --- /dev/null +++ b/.skills/csharp-async-best-practices/references/core-guidance.md @@ -0,0 +1,105 @@ +--- +description: >- + Source-backed core guidance for task, valuetask, cancellation, exception flow, + blocking, and concurrency in c# async code reviews and implementations. +metadata: + tags: [csharp, async, task, valuetask, cancellation, exceptions, concurrency] + source: mixed +--- + +# Core Guidance + +## Facts from official .NET documentation + +### 1. Return types and `async void` +- Async methods should normally return `Task` or `Task`. +- `async void` is intended for event handlers; callers cannot await it and exception handling differs. +- TAP methods that return awaitable types conventionally use the `Async` suffix. + +### 2. Blocking on async +- `Task.Result` is blocking. Prefer `await` in most cases. +- Blocking can deadlock in context-bound environments and reduces scalability even when it does not deadlock. +- `await` on a faulted task rethrows one exception directly; `.Wait()` and `.Result` wrap failures in `AggregateException`. + +### 3. `Task` versus `ValueTask` +- Default to `Task` or `Task` unless there is a demonstrated reason not to. +- `ValueTask` has stricter usage rules. A given instance should generally be awaited only once. +- Do not await the same `ValueTask` multiple times, call `AsTask()` multiple times, or mix consumption techniques on the same instance. +- For synchronously successful `Task`-returning methods, `Task.CompletedTask` is the normal zero-result completion value. + +### 4. Cancellation +- If a TAP method supports cancellation, expose a `CancellationToken`. +- Pass the token to nested operations that should participate in cancellation. +- If an async method throws `OperationCanceledException` associated with the method’s token, the returned task transitions to `Canceled`. +- After a method has completed its work successfully, do not report cancellation instead of success. + +### 5. Exception flow and task combinators +- `Task.WhenAll` does not block the calling thread. +- If any supplied task faults, the `WhenAll` task faults and aggregates the unwrapped exceptions from the component tasks. +- If none fault and at least one is canceled, the `WhenAll` task is canceled. +- `Task.WhenAny` returns a task that completes successfully with the first completed task as its result, even when that winning task itself is faulted or canceled. +- After `WhenAny`, await the returned winner task to propagate its outcome. +- The remaining tasks continue unless you cancel or otherwise handle them. + +## Expert guidance that is strong and technically grounded + +### Stephen Toub +- Use `ConfigureAwait(false)` as the general default for general-purpose library code, because library code should not depend on an app model’s context. +- App-level code is different. UI code often needs the captured context. ASP.NET Core also changes the deadlock discussion because it does not install the classic ASP.NET style synchronization context, but that does not make blanket `ConfigureAwait` advice strong. +- `ValueTask` exists mainly to avoid allocations on frequently synchronous success paths. It is not a general replacement for `Task` because `Task` is more flexible for multiple awaits, caching, and combinators. + +### Andrew Arnott +- Propagate the token until the point of no cancellation. +- Validate arguments before cancellation checks when argument validation should always run. +- Prefer catching `OperationCanceledException` rather than `TaskCanceledException` in general-purpose logic. +- Keep `CancellationToken` last in the parameter list; make it optional mainly on public APIs, not necessarily on internal methods. + +### Stephen Cleary +- “Async all the way” is a strong design guideline, not an absolute law of physics. Sync bridges exist, but they are specialized boundary decisions, not a normal code review recommendation. +- `async void` and sync-over-async both create real observability and composition problems even when a sample appears to work. + +## Naming and testability + +### Naming +- TAP methods that return awaitable types conventionally use the `Async` suffix. Do not force renames when an interface, base class, or event pattern already dictates the name. + +### Testability +- Favor awaitable APIs over hidden work so tests can await completion, assert faults, and drive cancellation deterministically. +- Prefer explicit background components, injected clocks, and owned queues over ad hoc fire-and-forget logic that tests cannot observe. + +## Synthesis for agents + +### Code review defaults +- Treat `.Result`, `.Wait()`, and `GetAwaiter().GetResult()` as likely defects unless the code is a deliberate sync boundary and the caller explicitly cannot be async. +- Prefer `Task`/`Task` for API design. Require an explicit reason before recommending `ValueTask`. +- Require cancellation behavior to be coherent: accepted, propagated, and not silently dropped. +- Prefer `await Task.WhenAll(...)` for independent operations started before awaiting. +- Treat `Task.WhenAny(...)` as incomplete until the winner is awaited and losers are canceled, observed, or intentionally left running. + +### Minimal examples + +#### Avoid sync-over-async +```csharp +// bad +var user = client.GetUserAsync(id).Result; + +// better +var user = await client.GetUserAsync(id); +``` + +#### Use `Task.WhenAll` for parallel I/O +```csharp +var userTask = repo.GetUserAsync(id, ct); +var ordersTask = repo.GetOrdersAsync(id, ct); +await Task.WhenAll(userTask, ordersTask); +return new Dashboard(await userTask, await ordersTask); +``` + +#### Be conservative with `ValueTask` +```csharp +// default +Task GetAsync(string key, CancellationToken ct); + +// specialized hot path only when justified +ValueTask TryGetCachedAsync(string key); +``` diff --git a/.skills/csharp-async-best-practices/references/source-notes.md b/.skills/csharp-async-best-practices/references/source-notes.md new file mode 100644 index 0000000000..0d34589d1d --- /dev/null +++ b/.skills/csharp-async-best-practices/references/source-notes.md @@ -0,0 +1,59 @@ +--- +description: >- + Authority notes and citations for the c# async best practices skill, separating + official documentation, expert interpretation, and synthesized guidance. +metadata: + tags: [sources, citations, authority, notes] + source: external +--- + +# Source Notes + +## Official facts + +- Microsoft Learn, "Implementing the Task-based Asynchronous Pattern" + - https://learn.microsoft.com/en-us/dotnet/standard/asynchronous-programming-patterns/implementing-the-task-based-asynchronous-pattern + - Return types, cancellation behavior, `Task.Run` boundaries, and TAP implementation guidance. +- Microsoft Learn, "Consuming the Task-based Asynchronous Pattern" + - https://learn.microsoft.com/en-us/dotnet/standard/asynchronous-programming-patterns/consuming-the-task-based-asynchronous-pattern + - `await`, `WhenAll`, `WhenAny`, cancellation propagation, and exception behavior. +- Microsoft Learn, "Async return types" + - https://learn.microsoft.com/en-us/dotnet/csharp/asynchronous-programming/async-return-types + - `Task`, `Task`, `async void`, generalized async return types. +- Microsoft Learn, `ValueTask` API reference + - https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.valuetask + - single-consumer warnings and default-to-`Task` guidance. +- Microsoft Learn, ASP.NET Core best practices + - https://learn.microsoft.com/en-us/aspnet/core/fundamentals/best-practices + - avoid blocking calls, avoid unnecessary `Task.Run`, background-work cautions. +- Microsoft Learn, hosted services in ASP.NET Core + - https://learn.microsoft.com/en-us/aspnet/core/fundamentals/host/hosted-services + - safe long-lived background work and cancellation during shutdown. + +## Expert guidance used only when technically grounded + +- Stephen Toub, ".NET Blog: ConfigureAwait FAQ" + - https://devblogs.microsoft.com/dotnet/configureawait-faq/ + - best source for context capture semantics and library-vs-app guidance. +- Stephen Toub, ".NET Blog: Understanding the Whys, Whats, and Whens of ValueTask" + - https://devblogs.microsoft.com/dotnet/understanding-the-whys-whats-and-whens-of-valuetask/ + - performance rationale and tradeoffs behind `ValueTask`. +- Stephen Toub, ".NET Blog: Await, and UI, and deadlocks! Oh my!" + - https://devblogs.microsoft.com/dotnet/await-and-ui-and-deadlocks-oh-my/ + - canonical deadlock explanation for context-bound code. +- Stephen Toub, ".NET Blog: Task Exception Handling in .NET 4.5" + - https://devblogs.microsoft.com/dotnet/task-exception-handling-in-net-4-5/ + - explains `await` versus blocking exception shape and why `WhenAll` matters. +- Andrew Arnott, "Recommended patterns for CancellationToken" + - https://devblogs.microsoft.com/premier-developer/recommended-patterns-for-cancellationtoken/ + - practical cancellation design heuristics; useful, but not treated as a language/runtime spec. +- Stephen Cleary, "Async/Await - Best Practices in Asynchronous Programming" + - https://learn.microsoft.com/en-us/archive/msdn-magazine/2013/march/async-await-best-practices-in-asynchronous-programming + - useful design interpretation, but older and treated as contextual guidance rather than current official policy. + +## Where the skill is intentionally cautious + +- `ConfigureAwait`: strong guidance exists for libraries, weaker guidance for app code. Blanket rules are rejected. +- `Task.Run`: valid for deliberate CPU offload, weak as a server-side patch for blocking I/O. +- `ValueTask`: supported and useful, but easy to misuse. The skill defaults to `Task` unless evidence is present. +- Fire-and-forget: acceptable only with explicit ownership and lifecycle design, especially in server code. diff --git a/DebugTools/MccMcpStdioHarness/Program.cs b/DebugTools/MccMcpStdioHarness/Program.cs index 1b90ddbb40..bb085f2772 100644 --- a/DebugTools/MccMcpStdioHarness/Program.cs +++ b/DebugTools/MccMcpStdioHarness/Program.cs @@ -567,6 +567,9 @@ public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool timeoutMs }); + public Task MoveToAsync(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) => + Task.FromResult(MoveTo(x, y, z, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); + public MccMcpResult MoveToPlayer(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) => MccMcpResult.Ok(new { @@ -593,6 +596,9 @@ public MccMcpResult MoveToPlayer(string playerName, bool allowUnsafe, bool allow timeoutMs }); + public Task MoveToPlayerAsync(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) => + Task.FromResult(MoveToPlayer(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); + public MccMcpResult LookAt(double x, double y, double z) => MccMcpResult.Ok(new { looked = true, x = C(x), y = C(y), z = C(z) }); @@ -755,6 +761,9 @@ public MccMcpResult OpenContainerAt(int x, int y, int z, int timeoutMs, bool clo }); } + public Task OpenContainerAtAsync(int x, int y, int z, int timeoutMs, bool closeCurrent) => + Task.FromResult(OpenContainerAt(x, y, z, timeoutMs, closeCurrent)); + public MccMcpResult CloseContainer(int inventoryId, int timeoutMs) { int resolvedInventoryId = inventoryId <= 0 ? 1 : inventoryId; @@ -768,6 +777,9 @@ public MccMcpResult CloseContainer(int inventoryId, int timeoutMs) }); } + public Task CloseContainerAsync(int inventoryId, int timeoutMs) => + Task.FromResult(CloseContainer(inventoryId, timeoutMs)); + public MccMcpResult InventoryWindowAction(int inventoryId, int slotId, string actionType) => MccMcpResult.Ok(new { success = true, inventoryId, slotId, actionType }); @@ -963,6 +975,9 @@ public MccMcpResult PickupItems(string itemType, double radius, int maxItems, bo } }); + public Task PickupItemsAsync(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs) => + Task.FromResult(PickupItems(itemType, radius, maxItems, allowUnsafe, timeoutMs)); + public MccMcpResult Respawn() { health = 20.0f; diff --git a/MinecraftClient/AutoTimeout.cs b/MinecraftClient/AutoTimeout.cs index 786be4db13..f08dc5e590 100644 --- a/MinecraftClient/AutoTimeout.cs +++ b/MinecraftClient/AutoTimeout.cs @@ -1,5 +1,6 @@ using System; using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { @@ -22,6 +23,11 @@ public static bool Perform(Action action, int timeout) return Perform(action, TimeSpan.FromMilliseconds(timeout)); } + public static Task PerformAsync(Action action, int timeout, CancellationToken cancellationToken = default) + { + return PerformAsync(action, TimeSpan.FromMilliseconds(timeout), cancellationToken); + } + /// /// Perform the specified action with specified timeout /// @@ -30,14 +36,26 @@ public static bool Perform(Action action, int timeout) /// True if the action finished whithout timing out public static bool Perform(Action action, TimeSpan timeout) { - Thread thread = new(new ThreadStart(action)); - thread.Start(); + return PerformAsync(action, timeout).GetAwaiter().GetResult(); + } - bool success = thread.Join(timeout); - if (!success) - thread.Interrupt(); + public static async Task PerformAsync(Action action, TimeSpan timeout, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(action); - return success; + try + { + await Task.Run(action, cancellationToken).WaitAsync(timeout, cancellationToken); + return true; + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + return false; + } + catch (TimeoutException) + { + return false; + } } } -} \ No newline at end of file +} diff --git a/MinecraftClient/ClassicConsoleBackend.cs b/MinecraftClient/ClassicConsoleBackend.cs index ac4912c914..bc8dedd822 100644 --- a/MinecraftClient/ClassicConsoleBackend.cs +++ b/MinecraftClient/ClassicConsoleBackend.cs @@ -1,4 +1,6 @@ using System; +using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { @@ -50,6 +52,13 @@ public string RequestImmediateInput() return ConsoleInteractive.ConsoleReader.RequestImmediateInput(); } + public Task RequestImmediateInputAsync(CancellationToken cancellationToken) + { + // ConsoleInteractive only exposes a blocking immediate-read API. + // Keep the compatibility boundary here so the wider startup/runtime path can await it. + return Task.Run(ConsoleInteractive.ConsoleReader.RequestImmediateInput, cancellationToken); + } + public string? ReadPassword() { ConsoleInteractive.ConsoleReader.SetInputVisible(false); @@ -58,6 +67,19 @@ public string RequestImmediateInput() return input; } + public async Task ReadPasswordAsync(CancellationToken cancellationToken) + { + ConsoleInteractive.ConsoleReader.SetInputVisible(false); + try + { + return await RequestImmediateInputAsync(cancellationToken); + } + finally + { + ConsoleInteractive.ConsoleReader.SetInputVisible(true); + } + } + public void ClearInputBuffer() { ConsoleInteractive.ConsoleReader.ClearBuffer(); diff --git a/MinecraftClient/ConsoleIO.cs b/MinecraftClient/ConsoleIO.cs index 8a77aaeefa..8245af1324 100644 --- a/MinecraftClient/ConsoleIO.cs +++ b/MinecraftClient/ConsoleIO.cs @@ -76,6 +76,13 @@ public static void SetAutoCompleteEngine(IAutoComplete engine) return Backend.ReadPassword(); } + public static Task ReadPasswordAsync(CancellationToken cancellationToken = default) + { + if (BasicIO) + return Task.FromResult(Console.ReadLine()); + return Backend.ReadPasswordAsync(cancellationToken); + } + /// /// Read a line from the standard input /// @@ -86,6 +93,13 @@ public static string ReadLine() return Backend.RequestImmediateInput(); } + public static Task ReadLineAsync(CancellationToken cancellationToken = default) + { + if (BasicIO) + return Task.FromResult(Console.ReadLine() ?? string.Empty); + return Backend.RequestImmediateInputAsync(cancellationToken); + } + /// /// Debug routine: print all keys pressed in the console /// @@ -233,84 +247,120 @@ private static void MccAutocompleteHandler(ConsoleInputBuffer buffer) DoClearSuggestions(); return; } + _cancellationTokenSource?.Cancel(); - using var cts = new CancellationTokenSource(); + var cts = new CancellationTokenSource(); _cancellationTokenSource = cts; - var previousTask = _latestTask; - var newTask = new Task(async () => + Task newTask = UpdateSuggestionsAsync(fullCommand, offset, buffer.CursorPosition, cts.Token); + _latestTask = newTask; + _ = ObserveAutocompleteTaskAsync(newTask, cts); + } + else + { + DoClearSuggestions(); + return; + } + } + + private static async Task UpdateSuggestionsAsync(string fullCommand, int offset, int cursorPosition, CancellationToken cancellationToken) + { + string command = fullCommand[offset..]; + if (command.Length == 0) + { + List suggestionList = new() { - string command = fullCommand[offset..]; - if (command.Length == 0) - { - List sugList = new(); + new("/") + }; - sugList.Add(new("/")); + var childs = McClient.dispatcher.GetRoot().Children; + if (childs is not null) + { + foreach (var child in childs) + suggestionList.Add(new(child.Name)); + } - var childs = McClient.dispatcher.GetRoot().Children; - if (childs is not null) - foreach (var child in childs) - sugList.Add(new(child.Name)); + foreach (var cmd in Commands) + suggestionList.Add(new(cmd)); - foreach (var cmd in Commands) - sugList.Add(new(cmd)); + if (cancellationToken.IsCancellationRequested) + return; - SendSuggestions(sugList.ToArray(), new(offset, offset)); - } - else if (command.Length > 0 && command[0] == '/' && !command.Contains(' ')) - { - var sorted = Process.ExtractSorted(command[1..], Commands); - var sugList = new ConsoleInteractive.ConsoleSuggestion.Suggestion[sorted.Count()]; + SendSuggestions(suggestionList.ToArray(), new(offset, offset)); + return; + } - int index = 0; - foreach (var sug in sorted) - sugList[index++] = new(sug.Value); - SendSuggestions(sugList, new(offset, offset + command.Length)); - } - else - { - CommandDispatcher? dispatcher = McClient.dispatcher; - if (dispatcher is null) - return; + if (command[0] == '/' && !command.Contains(' ')) + { + var sorted = Process.ExtractSorted(command[1..], Commands); + var suggestionList = new ConsoleInteractive.ConsoleSuggestion.Suggestion[sorted.Count()]; - ParseResults parse = dispatcher.Parse(command, CmdResult.Empty); + int index = 0; + foreach (var suggestion in sorted) + suggestionList[index++] = new(suggestion.Value); - Brigadier.NET.Suggestion.Suggestions suggestions = await dispatcher.GetCompletionSuggestions(parse, buffer.CursorPosition - offset); + if (cancellationToken.IsCancellationRequested) + return; - int sugLen = suggestions.List.Count; - if (sugLen == 0) - { - DoClearSuggestions(); - return; - } + SendSuggestions(suggestionList, new(offset, offset + command.Length)); + return; + } - Dictionary dictionary = new(); - foreach (var sug in suggestions.List) - dictionary.Add(sug.Text, sug.Tooltip?.String); + CommandDispatcher? dispatcher = McClient.dispatcher; + if (dispatcher is null) + return; - var sugList = new ConsoleInteractive.ConsoleSuggestion.Suggestion[sugLen]; - if (cts.IsCancellationRequested) - return; + ParseResults parse = dispatcher.Parse(command, CmdResult.Empty); + Brigadier.NET.Suggestion.Suggestions suggestions = + await dispatcher.GetCompletionSuggestions(parse, cursorPosition - offset); - Tuple range = new(suggestions.Range.Start + offset, suggestions.Range.End + offset); - var sorted = Process.ExtractSorted(fullCommand[range.Item1..range.Item2], dictionary.Keys); - if (cts.IsCancellationRequested) - return; + if (cancellationToken.IsCancellationRequested) + return; - int index = 0; - foreach (var sug in sorted) - sugList[index++] = new(sug.Value, dictionary[sug.Value] ?? string.Empty); + int suggestionCount = suggestions.List.Count; + if (suggestionCount == 0) + { + DoClearSuggestions(); + return; + } - SendSuggestions(sugList, range); - } - }, cts.Token); - _latestTask = newTask; - try { newTask.Start(); } catch { } - if (_cancellationTokenSource == cts) _cancellationTokenSource = null; + Dictionary tooltips = new(); + foreach (var suggestion in suggestions.List) + tooltips.Add(suggestion.Text, suggestion.Tooltip?.String); + + Tuple range = new(suggestions.Range.Start + offset, suggestions.Range.End + offset); + var sortedSuggestions = Process.ExtractSorted(fullCommand[range.Item1..range.Item2], tooltips.Keys); + if (cancellationToken.IsCancellationRequested) + return; + + var suggestionListWithTooltips = new ConsoleInteractive.ConsoleSuggestion.Suggestion[suggestionCount]; + int suggestionIndex = 0; + foreach (var suggestion in sortedSuggestions) + suggestionListWithTooltips[suggestionIndex++] = new(suggestion.Value, tooltips[suggestion.Value] ?? string.Empty); + + SendSuggestions(suggestionListWithTooltips, range); + } + + private static async Task ObserveAutocompleteTaskAsync(Task task, CancellationTokenSource cancellationTokenSource) + { + try + { + await task; } - else + catch (OperationCanceledException) when (cancellationTokenSource.IsCancellationRequested) + { + } + catch (Exception e) { + if (Settings.Config.Logging.DebugMessages) + WriteLogLine(e.ToString(), acceptnewlines: true); DoClearSuggestions(); - return; + } + finally + { + if (ReferenceEquals(_cancellationTokenSource, cancellationTokenSource)) + _cancellationTokenSource = null; + + cancellationTokenSource.Dispose(); } } diff --git a/MinecraftClient/Crypto/AesCfb8Stream.cs b/MinecraftClient/Crypto/AesCfb8Stream.cs index b60eb84528..381674a53a 100644 --- a/MinecraftClient/Crypto/AesCfb8Stream.cs +++ b/MinecraftClient/Crypto/AesCfb8Stream.cs @@ -1,7 +1,10 @@ using System; +using System.Buffers; using System.IO; using System.Runtime.CompilerServices; -using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; +using MinecraftClient.Crypto.AesHandler; namespace MinecraftClient.Crypto { @@ -9,8 +12,7 @@ public class AesCfb8Stream : Stream { public const int blockSize = 16; - private readonly Aes? Aes = null; - private readonly FastAes? FastAes = null; + private readonly IAesHandler aesHandler; private bool inStreamEnded = false; @@ -22,18 +24,7 @@ public class AesCfb8Stream : Stream public AesCfb8Stream(Stream stream, byte[] key) { BaseStream = stream; - - if (FastAes.IsSupported()) - FastAes = new FastAes(key); - else - { - Aes = Aes.Create(); - Aes.BlockSize = 128; - Aes.KeySize = 128; - Aes.Key = key; - Aes.Mode = CipherMode.ECB; - Aes.Padding = PaddingMode.None; - } + aesHandler = AesHandlerFactory.Create(key); Array.Copy(key, ReadStreamIV, 16); Array.Copy(key, WriteStreamIV, 16); @@ -59,6 +50,11 @@ public override void Flush() BaseStream.Flush(); } + public override Task FlushAsync(CancellationToken cancellationToken) + { + return BaseStream.FlushAsync(cancellationToken); + } + public override long Length { get { throw new NotSupportedException(); } @@ -89,10 +85,7 @@ public override int ReadByte() } Span blockOutput = stackalloc byte[blockSize]; - if (FastAes is not null) - FastAes.EncryptEcb(ReadStreamIV, blockOutput); - else - Aes!.EncryptEcb(ReadStreamIV, blockOutput, PaddingMode.None); + aesHandler.EncryptEcb(ReadStreamIV, blockOutput); // Shift left Array.Copy(ReadStreamIV, 1, ReadStreamIV, 0, blockSize - 1); @@ -101,6 +94,12 @@ public override int ReadByte() return (byte)(blockOutput[0] ^ inputBuf); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void EncryptBlock(ReadOnlySpan blockInput, Span blockOutput) + { + aesHandler.EncryptEcb(blockInput, blockOutput); + } + [MethodImpl(MethodImplOptions.AggressiveOptimization)] public override int Read(byte[] buffer, int outOffset, int required) { @@ -108,43 +107,39 @@ public override int Read(byte[] buffer, int outOffset, int required) return 0; Span blockOutput = stackalloc byte[blockSize]; + byte[] inputBuf = ArrayPool.Shared.Rent(blockSize + required); - byte[] inputBuf = new byte[blockSize + required]; - Array.Copy(ReadStreamIV, inputBuf, blockSize); - - for (int readed = 0, curRead; readed < required; readed += curRead) + try { - curRead = BaseStream.Read(inputBuf, blockSize + readed, required - readed); - if (curRead == 0) - { - inStreamEnded = true; - return readed; - } + Array.Copy(ReadStreamIV, inputBuf, blockSize); - int processEnd = readed + curRead; - if (FastAes is not null) + for (int readed = 0, curRead; readed < required; readed += curRead) { - for (int idx = readed; idx < processEnd; idx++) + curRead = BaseStream.Read(inputBuf, blockSize + readed, required - readed); + if (curRead == 0) { - ReadOnlySpan blockInput = new(inputBuf, idx, blockSize); - FastAes.EncryptEcb(blockInput, blockOutput); - buffer[outOffset + idx] = (byte)(blockOutput[0] ^ inputBuf[idx + blockSize]); + inStreamEnded = true; + Array.Copy(inputBuf, readed, ReadStreamIV, 0, blockSize); + return readed; } - } - else - { + + int processEnd = readed + curRead; for (int idx = readed; idx < processEnd; idx++) { ReadOnlySpan blockInput = new(inputBuf, idx, blockSize); - Aes!.EncryptEcb(blockInput, blockOutput, PaddingMode.None); + EncryptBlock(blockInput, blockOutput); buffer[outOffset + idx] = (byte)(blockOutput[0] ^ inputBuf[idx + blockSize]); } } - } - Array.Copy(inputBuf, required, ReadStreamIV, 0, blockSize); + Array.Copy(inputBuf, required, ReadStreamIV, 0, blockSize); - return required; + return required; + } + finally + { + ArrayPool.Shared.Return(inputBuf); + } } public override long Seek(long offset, SeekOrigin origin) @@ -161,10 +156,7 @@ public override void WriteByte(byte b) { Span blockOutput = stackalloc byte[blockSize]; - if (FastAes is not null) - FastAes.EncryptEcb(WriteStreamIV, blockOutput); - else - Aes!.EncryptEcb(WriteStreamIV, blockOutput, PaddingMode.None); + EncryptBlock(WriteStreamIV, blockOutput); byte outputBuf = (byte)(blockOutput[0] ^ b); @@ -178,22 +170,129 @@ public override void WriteByte(byte b) [MethodImpl(MethodImplOptions.AggressiveOptimization)] public override void Write(byte[] input, int offset, int required) { - byte[] outputBuf = new byte[blockSize + required]; - Array.Copy(WriteStreamIV, outputBuf, blockSize); + byte[] outputBuf = ArrayPool.Shared.Rent(blockSize + required); + + try + { + Array.Copy(WriteStreamIV, outputBuf, blockSize); + + Span blockOutput = stackalloc byte[blockSize]; + for (int written = 0; written < required; ++written) + { + ReadOnlySpan blockInput = new(outputBuf, written, blockSize); + EncryptBlock(blockInput, blockOutput); + outputBuf[blockSize + written] = (byte)(blockOutput[0] ^ input[offset + written]); + } + + BaseStream.Write(outputBuf, blockSize, required); + Array.Copy(outputBuf, required, WriteStreamIV, 0, blockSize); + } + finally + { + ArrayPool.Shared.Return(outputBuf); + } + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (inStreamEnded || buffer.Length == 0) + return 0; + byte[] inputBuf = ArrayPool.Shared.Rent(blockSize + buffer.Length); + + try + { + Array.Copy(ReadStreamIV, inputBuf, blockSize); + + for (int readed = 0; readed < buffer.Length;) + { + int curRead = await BaseStream.ReadAsync(inputBuf.AsMemory(blockSize + readed, buffer.Length - readed), cancellationToken); + if (curRead == 0) + { + inStreamEnded = true; + Array.Copy(inputBuf, readed, ReadStreamIV, 0, blockSize); + return readed; + } + + int processEnd = readed + curRead; + DecryptToOutputBuffer(inputBuf, buffer, readed, processEnd); + readed = processEnd; + } + + Array.Copy(inputBuf, buffer.Length, ReadStreamIV, 0, blockSize); + return buffer.Length; + } + finally + { + ArrayPool.Shared.Return(inputBuf); + } + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (buffer.Length == 0) + return; + + byte[] outputBuf = ArrayPool.Shared.Rent(blockSize + buffer.Length); + + try + { + Array.Copy(WriteStreamIV, outputBuf, blockSize); + EncryptToOutputBuffer(buffer, outputBuf); + + await BaseStream.WriteAsync(outputBuf.AsMemory(blockSize, buffer.Length), cancellationToken); + Array.Copy(outputBuf, buffer.Length, WriteStreamIV, 0, blockSize); + } + finally + { + ArrayPool.Shared.Return(outputBuf); + } + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + aesHandler.Dispose(); + } + + base.Dispose(disposing); + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + private void DecryptToOutputBuffer(byte[] inputBuf, Memory output, int start, int end) + { Span blockOutput = stackalloc byte[blockSize]; - for (int wirtten = 0; wirtten < required; ++wirtten) + for (int idx = start; idx < end; idx++) { - ReadOnlySpan blockInput = new(outputBuf, wirtten, blockSize); - if (FastAes is not null) - FastAes.EncryptEcb(blockInput, blockOutput); - else - Aes!.EncryptEcb(blockInput, blockOutput, PaddingMode.None); - outputBuf[blockSize + wirtten] = (byte)(blockOutput[0] ^ input[offset + wirtten]); + ReadOnlySpan blockInput = new(inputBuf, idx, blockSize); + EncryptBlock(blockInput, blockOutput); + output.Span[idx] = (byte)(blockOutput[0] ^ inputBuf[idx + blockSize]); } - BaseStream.WriteAsync(outputBuf, blockSize, required); + } - Array.Copy(outputBuf, required, WriteStreamIV, 0, blockSize); + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + private void EncryptToOutputBuffer(ReadOnlyMemory input, byte[] outputBuf) + { + Span blockOutput = stackalloc byte[blockSize]; + for (int written = 0; written < input.Length; ++written) + { + ReadOnlySpan blockInput = new(outputBuf, written, blockSize); + EncryptBlock(blockInput, blockOutput); + outputBuf[blockSize + written] = (byte)(blockOutput[0] ^ input.Span[written]); + } } } } diff --git a/MinecraftClient/Crypto/AesHandler/BasicAes.cs b/MinecraftClient/Crypto/AesHandler/BasicAes.cs new file mode 100644 index 0000000000..8c08572fcf --- /dev/null +++ b/MinecraftClient/Crypto/AesHandler/BasicAes.cs @@ -0,0 +1,31 @@ +using System; +using System.Security.Cryptography; + +namespace MinecraftClient.Crypto.AesHandler; + +public sealed class BasicAes : IAesHandler +{ + private readonly Aes aes; + + public BasicAes(byte[] key) + { + ArgumentNullException.ThrowIfNull(key); + + aes = Aes.Create(); + aes.BlockSize = 128; + aes.KeySize = 128; + aes.Key = key; + aes.Mode = CipherMode.ECB; + aes.Padding = PaddingMode.None; + } + + public override void EncryptEcb(ReadOnlySpan plaintext, Span destination) + { + aes.EncryptEcb(plaintext, destination, PaddingMode.None); + } + + public override void Dispose() + { + aes.Dispose(); + } +} diff --git a/MinecraftClient/Crypto/AesHandler/FasterAesArm.cs b/MinecraftClient/Crypto/AesHandler/FasterAesArm.cs new file mode 100644 index 0000000000..a5b8621d6b --- /dev/null +++ b/MinecraftClient/Crypto/AesHandler/FasterAesArm.cs @@ -0,0 +1,162 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; + +namespace MinecraftClient.Crypto.AesHandler; + +public sealed class FasterAesArm : IAesHandler +{ + private const int BlockSize = 16; + private const int Rounds = 10; + + private readonly byte[] enc; + + public FasterAesArm(ReadOnlySpan key) + { + enc = new byte[(Rounds + 1) * BlockSize]; + + int[] intKey = GenerateKeyExpansion(key); + for (int i = 0; i < intKey.Length; ++i) + { + enc[i * 4 + 0] = (byte)((intKey[i] >> 0) & 0xFF); + enc[i * 4 + 1] = (byte)((intKey[i] >> 8) & 0xFF); + enc[i * 4 + 2] = (byte)((intKey[i] >> 16) & 0xFF); + enc[i * 4 + 3] = (byte)((intKey[i] >> 24) & 0xFF); + } + } + + public static bool IsSupported() + { + return Aes.IsSupported && AdvSimd.IsSupported; + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public override void EncryptEcb(ReadOnlySpan plaintext, Span destination) + { + int position = 0; + int left = plaintext.Length; + + Vector128 key0 = Unsafe.ReadUnaligned>(ref enc[0 * BlockSize]); + Vector128 key1 = Unsafe.ReadUnaligned>(ref enc[1 * BlockSize]); + Vector128 key2 = Unsafe.ReadUnaligned>(ref enc[2 * BlockSize]); + Vector128 key3 = Unsafe.ReadUnaligned>(ref enc[3 * BlockSize]); + Vector128 key4 = Unsafe.ReadUnaligned>(ref enc[4 * BlockSize]); + Vector128 key5 = Unsafe.ReadUnaligned>(ref enc[5 * BlockSize]); + Vector128 key6 = Unsafe.ReadUnaligned>(ref enc[6 * BlockSize]); + Vector128 key7 = Unsafe.ReadUnaligned>(ref enc[7 * BlockSize]); + Vector128 key8 = Unsafe.ReadUnaligned>(ref enc[8 * BlockSize]); + Vector128 key9 = Unsafe.ReadUnaligned>(ref enc[9 * BlockSize]); + Vector128 key10 = Unsafe.ReadUnaligned>(ref enc[10 * BlockSize]); + + while (left >= BlockSize) + { + Vector128 block = Unsafe.ReadUnaligned>(ref Unsafe.AsRef(in plaintext[position])); + + block = Aes.Encrypt(block, key0); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key1); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key2); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key3); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key4); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key5); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key6); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key7); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key8); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key9); + block = AdvSimd.Xor(block, key10); + + Unsafe.WriteUnaligned(ref destination[position], block); + + position += BlockSize; + left -= BlockSize; + } + } + + private static int[] GenerateKeyExpansion(ReadOnlySpan rgbKey) + { + int[] encryptKeyExpansion = new int[4 * (Rounds + 1)]; + + int index = 0; + for (int i = 0; i < 4; ++i) + { + int i0 = rgbKey[index++]; + int i1 = rgbKey[index++]; + int i2 = rgbKey[index++]; + int i3 = rgbKey[index++]; + encryptKeyExpansion[i] = i3 << 24 | i2 << 16 | i1 << 8 | i0; + } + + for (int i = 4; i < 4 * (Rounds + 1); ++i) + { + int temp = encryptKeyExpansion[i - 1]; + + if (i % 4 == 0) + { + temp = SubWord(Rot3(temp)); + temp ^= Rcon[(i / 4) - 1]; + } + + encryptKeyExpansion[i] = encryptKeyExpansion[i - 4] ^ temp; + } + + return encryptKeyExpansion; + } + + private static int SubWord(int value) + { + return Sbox[value & 0xFF] + | Sbox[(value >> 8) & 0xFF] << 8 + | Sbox[(value >> 16) & 0xFF] << 16 + | Sbox[(value >> 24) & 0xFF] << 24; + } + + private static int Rot3(int value) + { + return (value << 24 & unchecked((int)0xFF000000)) | (value >> 8 & unchecked((int)0x00FFFFFF)); + } + + private static ReadOnlySpan Sbox => + [ + 99, 124, 119, 123, 242, 107, 111, 197, 48, 1, 103, 43, 254, 215, 171, 118, + 202, 130, 201, 125, 250, 89, 71, 240, 173, 212, 162, 175, 156, 164, 114, 192, + 183, 253, 147, 38, 54, 63, 247, 204, 52, 165, 229, 241, 113, 216, 49, 21, + 4, 199, 35, 195, 24, 150, 5, 154, 7, 18, 128, 226, 235, 39, 178, 117, + 9, 131, 44, 26, 27, 110, 90, 160, 82, 59, 214, 179, 41, 227, 47, 132, + 83, 209, 0, 237, 32, 252, 177, 91, 106, 203, 190, 57, 74, 76, 88, 207, + 208, 239, 170, 251, 67, 77, 51, 133, 69, 249, 2, 127, 80, 60, 159, 168, + 81, 163, 64, 143, 146, 157, 56, 245, 188, 182, 218, 33, 16, 255, 243, 210, + 205, 12, 19, 236, 95, 151, 68, 23, 196, 167, 126, 61, 100, 93, 25, 115, + 96, 129, 79, 220, 34, 42, 144, 136, 70, 238, 184, 20, 222, 94, 11, 219, + 224, 50, 58, 10, 73, 6, 36, 92, 194, 211, 172, 98, 145, 149, 228, 121, + 231, 200, 55, 109, 141, 213, 78, 169, 108, 86, 244, 234, 101, 122, 174, 8, + 186, 120, 37, 46, 28, 166, 180, 198, 232, 221, 116, 31, 75, 189, 139, 138, + 112, 62, 181, 102, 72, 3, 246, 14, 97, 53, 87, 185, 134, 193, 29, 158, + 225, 248, 152, 17, 105, 217, 142, 148, 155, 30, 135, 233, 206, 85, 40, 223, + 140, 161, 137, 13, 191, 230, 66, 104, 65, 153, 45, 15, 176, 84, 187, 22 + ]; + + private static ReadOnlySpan Rcon => + [ + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, + 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, 0x2F, 0x5E, 0xBC, 0x63, 0xC6, + 0x97, 0x35, 0x6A, 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91 + ]; +} diff --git a/MinecraftClient/Crypto/AesHandler/FasterAesX86.cs b/MinecraftClient/Crypto/AesHandler/FasterAesX86.cs new file mode 100644 index 0000000000..715f8d9ef6 --- /dev/null +++ b/MinecraftClient/Crypto/AesHandler/FasterAesX86.cs @@ -0,0 +1,89 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; + +namespace MinecraftClient.Crypto.AesHandler; + +public sealed class FasterAesX86 : IAesHandler +{ + private Vector128[] RoundKeys { get; } + + public FasterAesX86(ReadOnlySpan key) + { + RoundKeys = KeyExpansion(key); + } + + public static bool IsSupported() + { + return Sse2.IsSupported && Aes.IsSupported; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public override void EncryptEcb(ReadOnlySpan plaintext, Span destination) + { + Vector128[] keys = RoundKeys; + + ReadOnlySpan> blocks = MemoryMarshal.Cast>(plaintext); + Span> dest = MemoryMarshal.Cast>(destination); + + _ = keys[10]; + + for (int i = 0; i < blocks.Length; i++) + { + Vector128 b = blocks[i]; + + b = Sse2.Xor(b, keys[0]); + b = Aes.Encrypt(b, keys[1]); + b = Aes.Encrypt(b, keys[2]); + b = Aes.Encrypt(b, keys[3]); + b = Aes.Encrypt(b, keys[4]); + b = Aes.Encrypt(b, keys[5]); + b = Aes.Encrypt(b, keys[6]); + b = Aes.Encrypt(b, keys[7]); + b = Aes.Encrypt(b, keys[8]); + b = Aes.Encrypt(b, keys[9]); + b = Aes.EncryptLast(b, keys[10]); + + dest[i] = b; + } + } + + private static Vector128[] KeyExpansion(ReadOnlySpan key) + { + Vector128[] keys = new Vector128[20]; + + keys[0] = Unsafe.ReadUnaligned>(ref MemoryMarshal.GetReference(key)); + + MakeRoundKey(keys, 1, 0x01); + MakeRoundKey(keys, 2, 0x02); + MakeRoundKey(keys, 3, 0x04); + MakeRoundKey(keys, 4, 0x08); + MakeRoundKey(keys, 5, 0x10); + MakeRoundKey(keys, 6, 0x20); + MakeRoundKey(keys, 7, 0x40); + MakeRoundKey(keys, 8, 0x80); + MakeRoundKey(keys, 9, 0x1B); + MakeRoundKey(keys, 10, 0x36); + + for (int i = 1; i < 10; i++) + keys[10 + i] = Aes.InverseMixColumns(keys[i]); + + return keys; + } + + private static void MakeRoundKey(Vector128[] keys, int index, byte rcon) + { + Vector128 s = keys[index - 1]; + Vector128 t = keys[index - 1]; + + t = Aes.KeygenAssist(t, rcon); + t = Sse2.Shuffle(t.AsUInt32(), 0xFF).AsByte(); + + s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4)); + s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8)); + + keys[index] = Sse2.Xor(s, t); + } +} diff --git a/MinecraftClient/Crypto/AesHandlerFactory.cs b/MinecraftClient/Crypto/AesHandlerFactory.cs new file mode 100644 index 0000000000..0c2e7b3044 --- /dev/null +++ b/MinecraftClient/Crypto/AesHandlerFactory.cs @@ -0,0 +1,20 @@ +using System; +using MinecraftClient.Crypto.AesHandler; + +namespace MinecraftClient.Crypto; + +internal static class AesHandlerFactory +{ + public static IAesHandler Create(ReadOnlySpan key) + { + byte[] ownedKey = key.ToArray(); + + if (FasterAesX86.IsSupported()) + return new FasterAesX86(ownedKey); + + if (FasterAesArm.IsSupported()) + return new FasterAesArm(ownedKey); + + return new BasicAes(ownedKey); + } +} diff --git a/MinecraftClient/Crypto/IAesHandler.cs b/MinecraftClient/Crypto/IAesHandler.cs new file mode 100644 index 0000000000..1f34ec9001 --- /dev/null +++ b/MinecraftClient/Crypto/IAesHandler.cs @@ -0,0 +1,12 @@ +using System; + +namespace MinecraftClient.Crypto; + +public abstract class IAesHandler : IDisposable +{ + public abstract void EncryptEcb(ReadOnlySpan plaintext, Span destination); + + public virtual void Dispose() + { + } +} diff --git a/MinecraftClient/FileMonitor.cs b/MinecraftClient/FileMonitor.cs index 16f590a4f1..d67610afa1 100644 --- a/MinecraftClient/FileMonitor.cs +++ b/MinecraftClient/FileMonitor.cs @@ -3,6 +3,7 @@ using System.IO; using System.Text; using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { @@ -12,7 +13,7 @@ namespace MinecraftClient public class FileMonitor : IDisposable { private readonly Tuple? monitor = null; - private readonly Tuple? polling = null; + private readonly Tuple? polling = null; /// /// Create a new FileMonitor and start monitoring @@ -48,9 +49,9 @@ public FileMonitor(string folder, string filename, FileSystemEventHandler handle monitor = null; var cancellationTokenSource = new CancellationTokenSource(); - polling = new Tuple(new Thread(() => PollingThread(folder, filename, handler, cancellationTokenSource.Token)), cancellationTokenSource); - polling.Item1.Name = String.Format("{0} Polling thread: {1}", GetType().Name, Path.Combine(folder, filename)); - polling.Item1.Start(); + polling = new Tuple( + Task.Run(() => PollingLoopAsync(folder, filename, handler, cancellationTokenSource.Token), cancellationTokenSource.Token), + cancellationTokenSource); } } @@ -66,25 +67,29 @@ public void Dispose() } /// - /// Fallback polling thread for use when operating system does not support FileSystemWatcher + /// Fallback polling loop for use when operating system does not support FileSystemWatcher /// /// Folder to monitor /// File name to monitor /// Callback when file changes - private void PollingThread(string folder, string filename, FileSystemEventHandler handler, CancellationToken cancellationToken) + private async Task PollingLoopAsync(string folder, string filename, FileSystemEventHandler handler, CancellationToken cancellationToken) { string filePath = Path.Combine(folder, filename); DateTime lastWrite = GetLastWrite(filePath); - while (!cancellationToken.IsCancellationRequested) + using PeriodicTimer periodicTimer = new(TimeSpan.FromSeconds(5)); + try { - Thread.Sleep(5000); - DateTime lastWriteNew = GetLastWrite(filePath); - if (lastWriteNew != lastWrite) + while (await periodicTimer.WaitForNextTickAsync(cancellationToken)) { - lastWrite = lastWriteNew; - handler(this, new FileSystemEventArgs(WatcherChangeTypes.Changed, folder, filename)); + DateTime lastWriteNew = GetLastWrite(filePath); + if (lastWriteNew != lastWrite) + { + lastWrite = lastWriteNew; + handler(this, new FileSystemEventArgs(WatcherChangeTypes.Changed, folder, filename)); + } } } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { } } /// diff --git a/MinecraftClient/IConsoleBackend.cs b/MinecraftClient/IConsoleBackend.cs index b4d52ebf57..4759e0f865 100644 --- a/MinecraftClient/IConsoleBackend.cs +++ b/MinecraftClient/IConsoleBackend.cs @@ -1,4 +1,6 @@ using System; +using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { @@ -56,8 +58,12 @@ public interface IConsoleBackend string RequestImmediateInput(); + Task RequestImmediateInputAsync(CancellationToken cancellationToken); + string? ReadPassword(); + Task ReadPasswordAsync(CancellationToken cancellationToken); + void ClearInputBuffer(); bool DisplayUserInput { get; set; } diff --git a/MinecraftClient/MainThreadExecutionScope.cs b/MinecraftClient/MainThreadExecutionScope.cs new file mode 100644 index 0000000000..226e41ebdb --- /dev/null +++ b/MinecraftClient/MainThreadExecutionScope.cs @@ -0,0 +1,45 @@ +using System; +using System.Threading; + +namespace MinecraftClient +{ + internal static class MainThreadExecutionScope + { + private sealed class ScopeNode(object owner, ScopeNode? parent) : IDisposable + { + public object Owner { get; } = owner; + public ScopeNode? Parent { get; } = parent; + + public void Dispose() + { + if (!ReferenceEquals(s_currentScope.Value, this)) + throw new InvalidOperationException("Main-thread execution scope disposed out of order."); + + s_currentScope.Value = Parent; + } + } + + private static readonly AsyncLocal s_currentScope = new(); + + public static IDisposable Enter(object owner) + { + ScopeNode scopeNode = new(owner, s_currentScope.Value); + s_currentScope.Value = scopeNode; + return scopeNode; + } + + public static bool IsActive(object owner) + { + ScopeNode? scopeNode = s_currentScope.Value; + while (scopeNode is not null) + { + if (ReferenceEquals(scopeNode.Owner, owner)) + return true; + + scopeNode = scopeNode.Parent; + } + + return false; + } + } +} diff --git a/MinecraftClient/Mapping/Movement.cs b/MinecraftClient/Mapping/Movement.cs index 0e972e09fd..2bc3807c7a 100644 --- a/MinecraftClient/Mapping/Movement.cs +++ b/MinecraftClient/Mapping/Movement.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Threading; -using System.Threading.Tasks; namespace MinecraftClient.Mapping { @@ -150,17 +149,8 @@ public static Queue Move2Steps(Location start, Location goal, ref doub public static Queue? CalculatePath(World world, Location start, Location goal, bool allowUnsafe, int maxOffset, int minOffset, TimeSpan timeout) { - CancellationTokenSource cts = new(); - Task?> pathfindingTask = Task.Factory.StartNew(() => - CalculatePath(world, start, goal, allowUnsafe, maxOffset, minOffset, cts.Token)); - pathfindingTask.Wait(timeout); - if (!pathfindingTask.IsCompleted) - { - cts.Cancel(); - pathfindingTask.Wait(); - } - - return pathfindingTask.Result; + using CancellationTokenSource cts = new(timeout); + return CalculatePath(world, start, goal, allowUnsafe, maxOffset, minOffset, cts.Token); } /// @@ -713,4 +703,4 @@ public static bool CheckChunkLoading(World world, Location start, Location dest) return true; } } -} \ No newline at end of file +} diff --git a/MinecraftClient/McClient.cs b/MinecraftClient/McClient.cs index 4d23bb60ca..0ccaaf7599 100644 --- a/MinecraftClient/McClient.cs +++ b/MinecraftClient/McClient.cs @@ -4,6 +4,8 @@ using System.Net.Sockets; using System.Text; using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; using Brigadier.NET; using Brigadier.NET.Exceptions; using MinecraftClient.ChatBots; @@ -42,10 +44,12 @@ public class McClient : IMinecraftComHandler private readonly Queue chatQueue = new(); private static DateTime nextMessageSendTime = DateTime.MinValue; - private readonly Queue threadTasks = new(); + private Queue threadTasks = new(); private readonly Lock threadTasksLock = new(); private readonly Lock recipeBookLock = new(); private readonly Lock achievementsLock = new(); + private readonly Lock consoleCommandProcessingLock = new(); + private readonly Lock networkAutoCompleteLock = new(); private readonly List bots = new(); private static readonly List botsOnHold = new(); @@ -223,7 +227,11 @@ public Dictionary GetTeams() IMinecraftCom handler = null!; SessionToken _sessionToken; CancellationTokenSource? cmdprompt = null; - Tuple? timeoutdetector = null; + private Channel? consoleCommandChannel; + private Task? consoleCommandProcessingTask; + private TaskCompletionSource? pendingNetworkAutoCompleteRequest; + private TaskCompletionSource? pendingCommandListInitialization; + Tuple? timeoutdetector = null; private int transferInProgress = 0; public ILogger Log; @@ -310,9 +318,10 @@ public McClient(SessionToken session, PlayerKeyPair? playerKeyPair, string serve handler = Protocol.ProtocolHandler.GetProtocolHandler(client, protocolversion, forgeInfo, this); Log.Info(Translations.mcc_version_supported); - timeoutdetector = new(new Thread(new ParameterizedThreadStart(TimeoutDetector)), new CancellationTokenSource()); - timeoutdetector.Item1.Name = "MCC Connection timeout detector"; - timeoutdetector.Item1.Start(timeoutdetector.Item2.Token); + CancellationTokenSource timeoutDetectorCancellationTokenSource = new(); + Task timeoutDetectorTask = TimeoutDetectorAsync(timeoutDetectorCancellationTokenSource.Token); + timeoutdetector = new(timeoutDetectorTask, timeoutDetectorCancellationTokenSource); + _ = ObserveTimeoutDetectorAsync(timeoutDetectorTask, timeoutDetectorCancellationTokenSource.Token); try { @@ -324,10 +333,7 @@ public McClient(SessionToken session, PlayerKeyPair? playerKeyPair, string serve Log.Info(string.Format(Translations.mcc_joined, Config.Main.Advanced.InternalCmdChar.ToLogString())); - cmdprompt = new CancellationTokenSource(); - ConsoleIO.Backend.BeginReadThread(); - ConsoleIO.Backend.MessageReceived += ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange += ConsoleIO.AutocompleteHandler; + StartConsoleHandlers(); } else { @@ -363,15 +369,12 @@ public McClient(SessionToken session, PlayerKeyPair? playerKeyPair, string serve if (ReconnectionAttemptsLeft > 0) { Log.Info(string.Format(Translations.mcc_reconnect, ReconnectionAttemptsLeft)); - Thread.Sleep(5000); ReconnectionAttemptsLeft--; - Program.Restart(); + Program.Restart(5, announceDelay: false); } else if (InternalConfig.InteractiveMode) { - ConsoleIO.Backend.StopReadThread(); - ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + StopConsoleHandlers(); Program.HandleFailure(); } @@ -389,9 +392,7 @@ public McClient(SessionToken session, PlayerKeyPair? playerKeyPair, string serve // kick messages and Ignore_Kick_Message is false, or retry limit reached) if (InternalConfig.InteractiveMode) { - ConsoleIO.Backend.StopReadThread(); - ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + StopConsoleHandlers(); Program.HandleFailure(); } @@ -415,6 +416,7 @@ public void Transfer(string newHost, int newPort) try { Log.Info($"Initiating a transfer to: {newHost}:{newPort}"); + StopConsoleHandlers(); // Unload bots UnloadAllBots(); @@ -449,10 +451,7 @@ public void Transfer(string newHost, int newPort) UpdateKeepAlive(); Log.Info($"Successfully transferred connection and logged in to {newHost}:{newPort}."); - cmdprompt = new CancellationTokenSource(); - ConsoleIO.Backend.BeginReadThread(); - ConsoleIO.Backend.MessageReceived += ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange += ConsoleIO.AutocompleteHandler; + StartConsoleHandlers(); } else { @@ -490,15 +489,12 @@ public void Transfer(string newHost, int newPort) if (ReconnectionAttemptsLeft > 0) { Log.Info($"Reconnecting... Attempts left: {ReconnectionAttemptsLeft}"); - Thread.Sleep(5000); ReconnectionAttemptsLeft--; - Program.Restart(); + Program.Restart(5, announceDelay: false); } else if (InternalConfig.InteractiveMode) { - ConsoleIO.Backend.StopReadThread(); - ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + StopConsoleHandlers(); Program.HandleFailure(); } @@ -703,15 +699,22 @@ public void OnUpdate() } } + Queue? pendingThreadTasks = null; lock (threadTasksLock) { - while (threadTasks.Count > 0) + if (threadTasks.Count > 0) { - Action taskToRun = threadTasks.Dequeue(); - taskToRun(); + pendingThreadTasks = threadTasks; + threadTasks = new(); } } + if (pendingThreadTasks is not null) + { + while (pendingThreadTasks.Count > 0) + pendingThreadTasks.Dequeue().ExecuteSynchronously(); + } + lock (DigLock) { if (RemainingDiggingTime > 0) @@ -734,29 +737,44 @@ public void OnUpdate() /// /// Periodically checks for server keepalives and consider that connection has been lost if the last received keepalive is too old. /// - private void TimeoutDetector(object? o) + private async Task TimeoutDetectorAsync(CancellationToken cancellationToken) { UpdateKeepAlive(); - do + using PeriodicTimer periodicTimer = new(TimeSpan.FromSeconds(15)); + try { - Thread.Sleep(TimeSpan.FromSeconds(15)); - - if (((CancellationToken)o!).IsCancellationRequested) - return; - - lock (lastKeepAliveLock) + while (await periodicTimer.WaitForNextTickAsync(cancellationToken)) { - if (lastKeepAlive.AddSeconds(Config.Main.Advanced.TcpTimeout) < DateTime.Now) + lock (lastKeepAliveLock) { - if (((CancellationToken)o!).IsCancellationRequested) - return; + if (lastKeepAlive.AddSeconds(Config.Main.Advanced.TcpTimeout) < DateTime.Now) + { + cancellationToken.ThrowIfCancellationRequested(); - OnConnectionLost(ChatBot.DisconnectReason.ConnectionLost, Translations.error_timeout); - return; + OnConnectionLost(ChatBot.DisconnectReason.ConnectionLost, Translations.error_timeout); + return; + } } } } - while (!((CancellationToken)o!).IsCancellationRequested); + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + } + + private async Task ObserveTimeoutDetectorAsync(Task timeoutDetectorTask, CancellationToken cancellationToken) + { + try + { + await timeoutDetectorTask; + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (Exception e) + { + Log.Warn(e.ToString()); + } } /// @@ -770,6 +788,259 @@ private void UpdateKeepAlive() } } + private void StartConsoleHandlers() + { + if (ConsoleIO.Backend is null) + return; + + cmdprompt = new CancellationTokenSource(); + StartConsoleCommandProcessing(cmdprompt.Token); + ConsoleIO.Backend.BeginReadThread(); + ConsoleIO.Backend.MessageReceived += ConsoleReaderOnMessageReceived; + ConsoleIO.Backend.OnInputChange += ConsoleIO.AutocompleteHandler; + } + + private void StopConsoleHandlers() + { + if (ConsoleIO.Backend is not null) + { + ConsoleIO.Backend.StopReadThread(); + ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; + ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + } + + StopConsoleCommandProcessing(); + } + + private void StartConsoleCommandProcessing(CancellationToken cancellationToken) + { + lock (consoleCommandProcessingLock) + { + consoleCommandChannel = Channel.CreateUnbounded(new UnboundedChannelOptions() + { + SingleReader = true, + SingleWriter = false, + AllowSynchronousContinuations = false + }); + consoleCommandProcessingTask = ProcessConsoleMessagesAsync(consoleCommandChannel.Reader, cancellationToken); + _ = ObserveConsoleCommandProcessingAsync(consoleCommandProcessingTask, cancellationToken); + } + } + + private void StopConsoleCommandProcessing() + { + Channel? activeChannel; + + lock (consoleCommandProcessingLock) + { + activeChannel = consoleCommandChannel; + consoleCommandChannel = null; + } + + activeChannel?.Writer.TryComplete(); + + if (cmdprompt is not null) + { + cmdprompt.Cancel(); + cmdprompt = null; + } + + CancelPendingNetworkAutoComplete(); + CancelPendingCommandListInitialization(); + } + + private async Task ObserveConsoleCommandProcessingAsync(Task processingTask, CancellationToken cancellationToken) + { + try + { + await processingTask; + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (Exception e) + { + Log.Warn(e.ToString()); + } + finally + { + lock (consoleCommandProcessingLock) + { + if (ReferenceEquals(consoleCommandProcessingTask, processingTask)) + consoleCommandProcessingTask = null; + } + } + } + + private async Task ProcessConsoleMessagesAsync(ChannelReader channelReader, CancellationToken cancellationToken) + { + await foreach (string message in channelReader.ReadAllAsync(cancellationToken)) + { + if (cancellationToken.IsCancellationRequested) + return; + + if (TryParseBasicIoAutocompleteRequest(message, out _)) + await HandleBasicIoAutocompleteRequestAsync(message, cancellationToken); + else + await InvokeOnMainThreadAsync(() => HandleCommandPromptText(message)); + } + } + + private async Task HandleBasicIoAutocompleteRequestAsync(string text, CancellationToken cancellationToken) + { + try + { + string[] command = text[1..].Split((char)0x00); + if (command.Length < 2 || !command[0].Equals("autocomplete", StringComparison.OrdinalIgnoreCase)) + return; + + await WaitForCommandListInitializationAsync(cancellationToken); + + Task requestTask = InvokeRequired + ? await InvokeOnMainThreadAsync(() => BeginNetworkAutoCompleteRequest(command[1])) + : BeginNetworkAutoCompleteRequest(command[1]); + + await requestTask.WaitAsync(cancellationToken); + + if (command.Length > 1) + ConsoleIO.WriteLine((char)0x00 + "autocomplete" + (char)0x00 + ConsoleIO.AutoCompleteResult); + else ConsoleIO.WriteLine((char)0x00 + "autocomplete" + (char)0x00); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (OperationCanceledException) + { + } + } + + private static bool TryParseBasicIoAutocompleteRequest(string text, out string behindCursor) + { + behindCursor = string.Empty; + + if (!ConsoleIO.BasicIO || string.IsNullOrEmpty(text) || text[0] != (char)0x00) + return false; + + string[] command = text[1..].Split((char)0x00); + if (command.Length < 2 || !command[0].Equals("autocomplete", StringComparison.OrdinalIgnoreCase)) + return false; + + behindCursor = command[1]; + return true; + } + + private Task BeginNetworkAutoCompleteRequest(string behindCursor) + { + if (string.IsNullOrEmpty(behindCursor)) + return Task.FromResult(Array.Empty()); + + TaskCompletionSource request = new(TaskCreationOptions.RunContinuationsAsynchronously); + lock (networkAutoCompleteLock) + { + pendingNetworkAutoCompleteRequest?.TrySetException(new OperationCanceledException()); + pendingNetworkAutoCompleteRequest = request; + } + + try + { + if (handler.AutoComplete(behindCursor) < 0) + { + CompletePendingNetworkAutoComplete(Array.Empty()); + } + } + catch (Exception e) + { + lock (networkAutoCompleteLock) + { + if (ReferenceEquals(pendingNetworkAutoCompleteRequest, request)) + pendingNetworkAutoCompleteRequest = null; + } + request.TrySetException(e); + } + + return request.Task; + } + + private void BeginCommandListInitialization() + { + lock (networkAutoCompleteLock) + { + pendingCommandListInitialization?.TrySetCanceled(); + pendingCommandListInitialization = new(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + + private void CompletePendingNetworkAutoComplete(string[] result) + { + TaskCompletionSource? pendingRequest; + lock (networkAutoCompleteLock) + { + pendingRequest = pendingNetworkAutoCompleteRequest; + pendingNetworkAutoCompleteRequest = null; + } + + pendingRequest?.TrySetResult(result); + } + + private void CancelPendingNetworkAutoComplete() + { + TaskCompletionSource? pendingRequest; + lock (networkAutoCompleteLock) + { + pendingRequest = pendingNetworkAutoCompleteRequest; + pendingNetworkAutoCompleteRequest = null; + } + + pendingRequest?.TrySetCanceled(); + } + + private void CompletePendingCommandListInitialization() + { + TaskCompletionSource? pendingInitialization; + lock (networkAutoCompleteLock) + { + pendingInitialization = pendingCommandListInitialization; + pendingCommandListInitialization = null; + } + + pendingInitialization?.TrySetResult(true); + } + + private void CancelPendingCommandListInitialization() + { + TaskCompletionSource? pendingInitialization; + lock (networkAutoCompleteLock) + { + pendingInitialization = pendingCommandListInitialization; + pendingCommandListInitialization = null; + } + + pendingInitialization?.TrySetCanceled(); + } + + private async Task WaitForCommandListInitializationAsync(CancellationToken cancellationToken) + { + Task? initializationTask; + lock (networkAutoCompleteLock) + { + initializationTask = pendingCommandListInitialization?.Task; + } + + if (initializationTask is null) + return; + + try + { + await initializationTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken); + } + catch (TimeoutException) + { + } + catch (OperationCanceledException) + { + } + } + /// /// Disconnect the client from the server (initiated from MCC) /// @@ -781,6 +1052,7 @@ public void Disconnect() botsOnHold.Clear(); botsOnHold.AddRange(bots); + StopConsoleHandlers(); if (handler is not null) { @@ -788,12 +1060,6 @@ public void Disconnect() handler.Dispose(); } - if (cmdprompt is not null) - { - cmdprompt.Cancel(); - cmdprompt = null; - } - if (timeoutdetector is not null) { timeoutdetector.Item2.Cancel(); @@ -820,8 +1086,7 @@ public void OnConnectionLost(ChatBot.DisconnectReason reason, string message) if (timeoutdetector is not null) { - if (timeoutdetector is not null && Thread.CurrentThread != timeoutdetector.Item1) - timeoutdetector.Item2.Cancel(); + timeoutdetector.Item2.Cancel(); timeoutdetector = null; } @@ -872,9 +1137,7 @@ public void OnConnectionLost(ChatBot.DisconnectReason reason, string message) if (!will_restart) { - ConsoleIO.Backend.StopReadThread(); - ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + StopConsoleHandlers(); Program.HandleFailure(null, false, reason); } } @@ -885,16 +1148,18 @@ public void OnConnectionLost(ChatBot.DisconnectReason reason, string message) private void ConsoleReaderOnMessageReceived(object? sender, string e) { + Channel? activeChannel; + lock (consoleCommandProcessingLock) + { + activeChannel = consoleCommandChannel; + } - if (client.Client is null) + if (activeChannel is null || client.Client is null) return; if (client.Client.Connected) { - new Thread(() => - { - InvokeOnMainThread(() => HandleCommandPromptText(e)); - }).Start(); + activeChannel.Writer.TryWrite(e); } else return; @@ -916,56 +1181,45 @@ private void HandleCommandPromptText(string text) { if (ConsoleIO.BasicIO && text.Length > 0 && text[0] == (char)0x00) { - //Process a request from the GUI - string[] command = text[1..].Split((char)0x00); - switch (command[0].ToLower()) - { - case "autocomplete": - int id = handler.AutoComplete(command[1]); - while (!ConsoleIO.AutoCompleteDone) { Thread.Sleep(100); } - if (command.Length > 1) { ConsoleIO.WriteLine((char)0x00 + "autocomplete" + (char)0x00 + ConsoleIO.AutoCompleteResult); } - else ConsoleIO.WriteLine((char)0x00 + "autocomplete" + (char)0x00); - break; - } + _ = HandleBasicIoAutocompleteRequestAsync(text, CancellationToken.None); + return; } - else - { - text = text.Trim(); - if (text.Length > 1 - && Config.Main.Advanced.InternalCmdChar == MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none - && text[0] == '/') - { - SendText(text); - } - else if (text.Length > 2 - && Config.Main.Advanced.InternalCmdChar != MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none - && text[0] == Config.Main.Advanced.InternalCmdChar.ToChar() - && text[1] == '/') - { - SendText(text[1..]); - } - else if (text.Length > 0) + text = text.Trim(); + + if (text.Length > 1 + && Config.Main.Advanced.InternalCmdChar == MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none + && text[0] == '/') + { + SendText(text); + } + else if (text.Length > 2 + && Config.Main.Advanced.InternalCmdChar != MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none + && text[0] == Config.Main.Advanced.InternalCmdChar.ToChar() + && text[1] == '/') + { + SendText(text[1..]); + } + else if (text.Length > 0) + { + if (Config.Main.Advanced.InternalCmdChar == MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none + || text[0] == Config.Main.Advanced.InternalCmdChar.ToChar()) { - if (Config.Main.Advanced.InternalCmdChar == MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none - || text[0] == Config.Main.Advanced.InternalCmdChar.ToChar()) + CmdResult result = new(); + string command = Config.Main.Advanced.InternalCmdChar.ToChar() == ' ' ? text : text[1..]; + if (!PerformInternalCommand(Config.AppVar.ExpandVars(command), ref result, Settings.Config.AppVar.GetVariables()) && Config.Main.Advanced.InternalCmdChar.ToChar() == '/') { - CmdResult result = new(); - string command = Config.Main.Advanced.InternalCmdChar.ToChar() == ' ' ? text : text[1..]; - if (!PerformInternalCommand(Config.AppVar.ExpandVars(command), ref result, Settings.Config.AppVar.GetVariables()) && Config.Main.Advanced.InternalCmdChar.ToChar() == '/') - { - SendText(text); - } - else if (result.status != CmdResult.Status.NotRun && (result.status != CmdResult.Status.Done || !string.IsNullOrWhiteSpace(result.result))) - { - Log.Info(result); - } + SendText(text); } - else + else if (result.status != CmdResult.Status.NotRun && (result.status != CmdResult.Status.Done || !string.IsNullOrWhiteSpace(result.result))) { - SendText(text); + Log.Info(result); } } + else + { + SendText(text); + } } } @@ -1099,19 +1353,7 @@ public void UnloadAllBots() /// Type of the return value public T InvokeOnMainThread(Func task) { - if (!InvokeRequired) - { - return task(); - } - else - { - TaskWithResult taskWithResult = new(task); - lock (threadTasksLock) - { - threadTasks.Enqueue(taskWithResult.ExecuteSynchronously); - } - return taskWithResult.WaitGetResult(); - } + return InvokeOnMainThreadAsync(task).GetAwaiter().GetResult(); } /// @@ -1126,6 +1368,37 @@ public void InvokeOnMainThread(Action task) InvokeOnMainThread(() => { task(); return true; }); } + private Task InvokeOnMainThreadAsync(Func task) + { + if (!InvokeRequired) + { + try + { + return Task.FromResult(task()); + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + TaskWithResult taskWithResult = new(task); + lock (threadTasksLock) + { + threadTasks.Enqueue(taskWithResult); + } + return taskWithResult.AsTask(); + } + + private Task InvokeOnMainThreadAsync(Action task) + { + return InvokeOnMainThreadAsync(() => + { + task(); + return true; + }); + } + /// /// Clear all tasks /// @@ -1133,7 +1406,8 @@ public void ClearTasks() { lock (threadTasksLock) { - threadTasks.Clear(); + while (threadTasks.Count > 0) + threadTasks.Dequeue().Cancel(); } } @@ -1145,16 +1419,13 @@ public bool InvokeRequired { get { - int callingThreadId = Environment.CurrentManagedThreadId; - if (handler is not null) - { - return handler.GetNetMainThreadId() != callingThreadId; - } - else + if (handler is null) { // net read thread (main thread) not yet ready return false; } + + return !MainThreadExecutionScope.IsActive(this); } } @@ -3040,6 +3311,7 @@ public void OnGameJoined(bool isOnlineMode) DispatchBotEvent(bot => bot.AfterGameJoined()); + BeginCommandListInitialization(); ConsoleIO.InitCommandList(dispatcher); } @@ -4445,6 +4717,8 @@ public void OnBlockEntityData(Location location, Dictionary? nbt public void OnAutoCompleteDone(int transactionId, string[] result) { ConsoleIO.OnAutoCompleteDone(transactionId, result); + CompletePendingNetworkAutoComplete(result); + CompletePendingCommandListInitialization(); } public void SetCanSendMessage(bool canSendMessage) diff --git a/MinecraftClient/Mcp/IMccMcpCapabilities.cs b/MinecraftClient/Mcp/IMccMcpCapabilities.cs index 055fb56146..1054bfb6a6 100644 --- a/MinecraftClient/Mcp/IMccMcpCapabilities.cs +++ b/MinecraftClient/Mcp/IMccMcpCapabilities.cs @@ -1,3 +1,5 @@ +using System.Threading.Tasks; + namespace MinecraftClient.Mcp; public interface IMccMcpCapabilities @@ -43,7 +45,9 @@ public interface IMccMcpCapabilities MccMcpResult FindNearestEntity(string? typeFilter, string? nameFilter, double radius, bool includePlayers); MccMcpResult CanReachPosition(double x, double y, double z, bool allowUnsafe, int maxOffset, int minOffset, int timeoutMs); MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs); + Task MoveToAsync(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs); MccMcpResult MoveToPlayer(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs); + Task MoveToPlayerAsync(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs); MccMcpResult LookAt(double x, double y, double z); MccMcpResult LookDirection(string direction); MccMcpResult LookAngles(float yaw, float pitch); @@ -51,7 +55,9 @@ public interface IMccMcpCapabilities MccMcpResult GetInventorySnapshot(int inventoryId); MccMcpResult SearchInventories(string query, int maxCount, bool exactMatch, bool includeContainers); MccMcpResult OpenContainerAt(int x, int y, int z, int timeoutMs, bool closeCurrent); + Task OpenContainerAtAsync(int x, int y, int z, int timeoutMs, bool closeCurrent); MccMcpResult CloseContainer(int inventoryId, int timeoutMs); + Task CloseContainerAsync(int inventoryId, int timeoutMs); MccMcpResult InventoryWindowAction(int inventoryId, int slotId, string actionType); MccMcpResult DropInventoryItem(string itemType, int count, int inventoryId, bool preferStack); MccMcpResult DepositContainerItem(string itemType, int count, int inventoryId, bool preferLargestStack); @@ -62,5 +68,6 @@ public interface IMccMcpCapabilities MccMcpResult FindSigns(string text, bool exactMatch, int radius, int maxCount, bool includeBackText); MccMcpResult ListItemEntities(string? itemType, double radius, int maxCount); MccMcpResult PickupItems(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs); + Task PickupItemsAsync(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs); MccMcpResult GetWorldBlockAt(int x, int y, int z); } diff --git a/MinecraftClient/Mcp/MccEmbeddedMcpHost.cs b/MinecraftClient/Mcp/MccEmbeddedMcpHost.cs index a40b718342..b6a9eb7021 100644 --- a/MinecraftClient/Mcp/MccEmbeddedMcpHost.cs +++ b/MinecraftClient/Mcp/MccEmbeddedMcpHost.cs @@ -1,4 +1,6 @@ using System; +using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; @@ -12,7 +14,7 @@ public sealed class MccEmbeddedMcpHost { private readonly MccMcpConfig config; private readonly IMccMcpCapabilities capabilities; - private readonly object stateLock = new(); + private readonly SemaphoreSlim stateLock = new(1, 1); private WebApplication? app; public MccEmbeddedMcpHost(MccMcpConfig config, IMccMcpCapabilities capabilities) @@ -25,10 +27,7 @@ public bool IsRunning { get { - lock (stateLock) - { - return app is not null; - } + return app is not null; } } @@ -36,29 +35,37 @@ public bool IsRunning public bool Start(out string? error) { - lock (stateLock) + (bool success, string? startError) = StartAsync().GetAwaiter().GetResult(); + error = startError; + return success; + } + + public bool Stop(out string? error) + { + (bool success, string? stopError) = StopAsync().GetAwaiter().GetResult(); + error = stopError; + return success; + } + + public async Task<(bool Success, string? Error)> StartAsync(CancellationToken cancellationToken = default) + { + await stateLock.WaitAsync(cancellationToken); + try { - error = null; if (app is not null) - return true; + return (true, null); string route = NormalizeRoute(config.Transport.Route); string bindHost = string.IsNullOrWhiteSpace(config.Transport.BindHost) ? "127.0.0.1" : config.Transport.BindHost.Trim(); if (config.Transport.Port is < 1 or > 65535) - { - error = "invalid_port"; - return false; - } + return (false, "invalid_port"); string? requiredToken = null; if (config.Transport.RequireAuthToken) { requiredToken = Environment.GetEnvironmentVariable(config.Transport.AuthTokenEnvVar); if (string.IsNullOrWhiteSpace(requiredToken)) - { - error = "missing_auth_token"; - return false; - } + return (false, "missing_auth_token"); } WebApplicationBuilder builder = WebApplication.CreateBuilder(); @@ -96,33 +103,49 @@ public bool Start(out string? error) } builtApp.MapMcp(route); - builtApp.StartAsync().GetAwaiter().GetResult(); - app = builtApp; - return true; + + try + { + await builtApp.StartAsync(cancellationToken); + app = builtApp; + return (true, null); + } + catch + { + await builtApp.DisposeAsync(); + throw; + } + } + finally + { + stateLock.Release(); } } - public bool Stop(out string? error) + public async Task<(bool Success, string? Error)> StopAsync(CancellationToken cancellationToken = default) { - lock (stateLock) + await stateLock.WaitAsync(cancellationToken); + try { - error = null; if (app is null) - return true; + return (true, null); try { - app.StopAsync().GetAwaiter().GetResult(); - app.DisposeAsync().AsTask().GetAwaiter().GetResult(); + await app.StopAsync(cancellationToken); + await app.DisposeAsync(); app = null; - return true; + return (true, null); } catch { - error = "stop_failed"; - return false; + return (false, "stop_failed"); } } + finally + { + stateLock.Release(); + } } private static string NormalizeRoute(string route) diff --git a/MinecraftClient/Mcp/MccMcpCapabilities.cs b/MinecraftClient/Mcp/MccMcpCapabilities.cs index 60a03f767d..a1127d184a 100644 --- a/MinecraftClient/Mcp/MccMcpCapabilities.cs +++ b/MinecraftClient/Mcp/MccMcpCapabilities.cs @@ -71,6 +71,8 @@ private sealed class NearbyItemSnapshot public required double Distance { get; init; } } + private readonly record struct ContainerOpenState(int InventoryId, Container? Inventory); + private enum InventoryTransferDirection { Deposit, @@ -1274,6 +1276,11 @@ public MccMcpResult FindNearestEntity(string? typeFilter, string? nameFilter, do } public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) + { + return MoveToAsync(x, y, z, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs).GetAwaiter().GetResult(); + } + + public async Task MoveToAsync(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) { if (!IsCategoryEnabled(t => t.Movement)) return MccMcpResult.Fail("capability_disabled"); @@ -1302,9 +1309,12 @@ public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool int verifyWaitMs = GetArrivalWaitMs(timeoutMs); double tolerance = GetArrivalTolerance(maxOffset, minOffset); - Location? finalLocation = null; - bool arrived = pathFound && WaitForArrival(client, goal, verifyWaitMs, tolerance, out finalLocation); - finalLocation ??= client.InvokeOnMainThread(client.GetCurrentLocation); + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + bool arrived = false; + if (pathFound) + { + (arrived, finalLocation) = await WaitForArrivalAsync(client, goal, verifyWaitMs, tolerance); + } object resultData = new { pathFound, @@ -1313,9 +1323,9 @@ public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool verifyWaitMs, target = ToCoordinate(goal), startLocation = ToCoordinate(startLocation), - finalLocation = ToCoordinate(finalLocation.Value), - finalDistance = GetDistance(finalLocation.Value, goal), - distanceMoved = GetDistance(startLocation, finalLocation.Value), + finalLocation = ToCoordinate(finalLocation), + finalDistance = GetDistance(finalLocation, goal), + distanceMoved = GetDistance(startLocation, finalLocation), allowUnsafe, allowDirectTeleport, maxOffset, @@ -1329,10 +1339,15 @@ public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool } public MccMcpResult MoveToPlayer(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) + { + return MoveToPlayerAsync(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs).GetAwaiter().GetResult(); + } + + public async Task MoveToPlayerAsync(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) { if (!IsCategoryEnabled(t => t.Movement)) return MccMcpResult.Fail("capability_disabled"); - return ToMcpResult(game.MoveToPlayer(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); + return ToMcpResult(await game.MoveToPlayerAsync(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); } public MccMcpResult LookAt(double x, double y, double z) @@ -1464,6 +1479,11 @@ public MccMcpResult ListInventories() } public MccMcpResult OpenContainerAt(int x, int y, int z, int timeoutMs, bool closeCurrent) + { + return OpenContainerAtAsync(x, y, z, timeoutMs, closeCurrent).GetAwaiter().GetResult(); + } + + public async Task OpenContainerAtAsync(int x, int y, int z, int timeoutMs, bool closeCurrent) { if (!IsCategoryEnabled(t => t.Inventory)) return MccMcpResult.Fail("capability_disabled"); @@ -1495,10 +1515,15 @@ public MccMcpResult OpenContainerAt(int x, int y, int z, int timeoutMs, bool clo }); } - return OpenContainerCore(client, location, state.block, state.activeContainerId, waitMs, closeCurrent); + return await OpenContainerCoreAsync(client, location, state.block, state.activeContainerId, waitMs, closeCurrent); } public MccMcpResult CloseContainer(int inventoryId, int timeoutMs) + { + return CloseContainerAsync(inventoryId, timeoutMs).GetAwaiter().GetResult(); + } + + public async Task CloseContainerAsync(int inventoryId, int timeoutMs) { if (!IsCategoryEnabled(t => t.Inventory)) return MccMcpResult.Fail("capability_disabled"); @@ -1527,7 +1552,7 @@ public MccMcpResult CloseContainer(int inventoryId, int timeoutMs) } bool closeAccepted = client.CloseInventory(resolvedInventoryId); - bool closed = closeAccepted && WaitForContainerClose(client, resolvedInventoryId, waitMs); + bool closed = closeAccepted && await WaitForContainerCloseAsync(client, resolvedInventoryId, waitMs); var resultData = new { success = closeAccepted && closed, @@ -1824,10 +1849,15 @@ public MccMcpResult ListItemEntities(string? itemType, double radius, int maxCou } public MccMcpResult PickupItems(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs) + { + return PickupItemsAsync(itemType, radius, maxItems, allowUnsafe, timeoutMs).GetAwaiter().GetResult(); + } + + public async Task PickupItemsAsync(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs) { if (!IsCategoryEnabled(t => t.EntityWorld) || !IsCategoryEnabled(t => t.Movement)) return MccMcpResult.Fail("capability_disabled"); - return ToMcpResult(game.PickupItems(itemType, radius, maxItems, allowUnsafe, timeoutMs)); + return ToMcpResult(await game.PickupItemsAsync(itemType, radius, maxItems, allowUnsafe, timeoutMs)); } public MccMcpResult GetWorldBlockAt(int x, int y, int z) @@ -1858,7 +1888,7 @@ public MccMcpResult GetWorldBlockAt(int x, int y, int z) }); } - private static MccMcpResult OpenContainerCore(McClient client, Location location, Block block, int activeContainerId, int waitMs, bool closeCurrent) + private static async Task OpenContainerCoreAsync(McClient client, Location location, Block block, int activeContainerId, int waitMs, bool closeCurrent) { if (activeContainerId > 0) { @@ -1876,7 +1906,7 @@ private static MccMcpResult OpenContainerCore(McClient client, Location location } bool closeAccepted = client.CloseInventory(activeContainerId); - bool closed = closeAccepted && WaitForContainerClose(client, activeContainerId, waitMs); + bool closed = closeAccepted && await WaitForContainerCloseAsync(client, activeContainerId, waitMs); if (!closeAccepted || !closed) { return MccMcpResult.Fail("action_incomplete", data: new @@ -1894,7 +1924,11 @@ private static MccMcpResult OpenContainerCore(McClient client, Location location int openedInventoryId = 0; Container? openedInventory = null; bool openAccepted = client.InvokeOnMainThread(() => client.PlaceBlock(location, Direction.Down, Hand.MainHand, lookAtBlock: true)); - bool opened = openAccepted && WaitForContainerOpen(client, beforeIds, waitMs, out openedInventoryId, out openedInventory); + bool opened = openAccepted && await WaitForContainerOpenAsync(client, beforeIds, waitMs, result => + { + openedInventoryId = result.InventoryId; + openedInventory = result.Inventory; + }); var resultData = new { success = openAccepted && opened && openedInventory is not null, @@ -2233,6 +2267,31 @@ private static bool WaitForContainerOpen(McClient client, ISet beforeIds, i } } + private static async Task WaitForContainerOpenAsync(McClient client, ISet beforeIds, int waitMs, Action onOpened) + { + DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); + while (true) + { + (int activeId, Container? activeInventory) state = client.InvokeOnMainThread(() => + { + int activeId = GetActiveContainerId(client); + Container? activeInventory = activeId > 0 ? client.GetInventory(activeId) : null; + return (activeId, activeInventory); + }); + + if (state.activeId > 0 && (!beforeIds.Contains(state.activeId) || beforeIds.Count == 0) && state.activeInventory is not null) + { + onOpened(new ContainerOpenState(state.activeId, state.activeInventory)); + return true; + } + + if (DateTime.UtcNow >= deadline) + return false; + + await Task.Delay(ArrivalPollIntervalMs); + } + } + private static bool WaitForContainerClose(McClient client, int inventoryId, int waitMs) { DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); @@ -2249,6 +2308,22 @@ private static bool WaitForContainerClose(McClient client, int inventoryId, int } } + private static async Task WaitForContainerCloseAsync(McClient client, int inventoryId, int waitMs) + { + DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); + while (true) + { + bool stillOpen = client.InvokeOnMainThread(() => client.GetInventories().ContainsKey(inventoryId)); + if (!stillOpen) + return true; + + if (DateTime.UtcNow >= deadline) + return false; + + await Task.Delay(ArrivalPollIntervalMs); + } + } + private static int GetContainerWaitMs(int timeoutMs) { if (timeoutMs <= 0) @@ -2901,6 +2976,24 @@ private static bool WaitForArrival(McClient client, Location goal, int waitMs, d } } + private static async Task<(bool Arrived, Location FinalLocation)> WaitForArrivalAsync(McClient client, Location goal, int waitMs, double tolerance) + { + DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + while (true) + { + finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + double distance = GetDistance(finalLocation, goal); + if (distance <= tolerance) + return (true, finalLocation); + + if (DateTime.UtcNow >= deadline) + return (false, finalLocation); + + await Task.Delay(ArrivalPollIntervalMs); + } + } + private static double GetDistance(Location from, Location to) { double dx = from.X - to.X; diff --git a/MinecraftClient/Mcp/MccMcpToolSet.cs b/MinecraftClient/Mcp/MccMcpToolSet.cs index 84305e3480..46fe955ea0 100644 --- a/MinecraftClient/Mcp/MccMcpToolSet.cs +++ b/MinecraftClient/Mcp/MccMcpToolSet.cs @@ -1,4 +1,5 @@ using System.ComponentModel; +using System.Threading.Tasks; using ModelContextProtocol.Server; namespace MinecraftClient.Mcp; @@ -262,15 +263,15 @@ public object CanReachPosition(double x, double y, double z, bool allowUnsafe = } [McpServerTool(Name = "mcc_move_to"), Description("Request movement/pathing to a world coordinate and verify arrival.")] - public object MoveTo(double x, double y, double z, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) + public async Task MoveTo(double x, double y, double z, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) { - return capabilities.MoveTo(x, y, z, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs); + return await capabilities.MoveToAsync(x, y, z, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs); } [McpServerTool(Name = "mcc_move_to_player"), Description("Locate a tracked player entity, request movement/pathing, and verify arrival.")] - public object MoveToPlayer(string playerName, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) + public async Task MoveToPlayer(string playerName, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) { - return capabilities.MoveToPlayer(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs); + return await capabilities.MoveToPlayerAsync(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs); } [McpServerTool(Name = "mcc_look_at"), Description("Rotate player view toward world coordinates.")] @@ -310,15 +311,15 @@ public object InventoriesList() } [McpServerTool(Name = "mcc_container_open_at"), Description("Open an interactable container block at world coordinates and wait for the container inventory to appear.")] - public object ContainerOpenAt(int x, int y, int z, int timeoutMs = 0, bool closeCurrent = true) + public async Task ContainerOpenAt(int x, int y, int z, int timeoutMs = 0, bool closeCurrent = true) { - return capabilities.OpenContainerAt(x, y, z, timeoutMs, closeCurrent); + return await capabilities.OpenContainerAtAsync(x, y, z, timeoutMs, closeCurrent); } [McpServerTool(Name = "mcc_container_close"), Description("Close an open non-player container. Use inventoryId=-1 to close the active container.")] - public object ContainerClose([Description("Container inventory ID, or -1 for the active non-player container.")] int inventoryId = -1, int timeoutMs = 0) + public async Task ContainerClose([Description("Container inventory ID, or -1 for the active non-player container.")] int inventoryId = -1, int timeoutMs = 0) { - return capabilities.CloseContainer(inventoryId, timeoutMs); + return await capabilities.CloseContainerAsync(inventoryId, timeoutMs); } [McpServerTool(Name = "mcc_inventory_window_action"), Description("Perform a window action on an inventory slot.")] @@ -388,9 +389,9 @@ public object ItemsList(string? itemType = null, double radius = 32, int maxCoun } [McpServerTool(Name = "mcc_items_pickup"), Description("Move to and pick up nearby dropped items of a given item type.")] - public object ItemsPickup(string itemType, double radius = 32, int maxItems = 20, bool allowUnsafe = false, int timeoutMs = 0) + public async Task ItemsPickup(string itemType, double radius = 32, int maxItems = 20, bool allowUnsafe = false, int timeoutMs = 0) { - return capabilities.PickupItems(itemType, radius, maxItems, allowUnsafe, timeoutMs); + return await capabilities.PickupItemsAsync(itemType, radius, maxItems, allowUnsafe, timeoutMs); } [McpServerTool(Name = "mcc_world_block_at"), Description("Get block information at world coordinates.")] diff --git a/MinecraftClient/Program.cs b/MinecraftClient/Program.cs index 58a42b9ba2..cfdb84fc8d 100644 --- a/MinecraftClient/Program.cs +++ b/MinecraftClient/Program.cs @@ -74,7 +74,7 @@ internal sealed class StartupState /// /// The main entry point of Minecraft Console Client /// - static void Main(string[] args) + static async Task Main(string[] args) { // [SENTRY] Initialize Sentry SDK only if the DSN is not empty if (SentryDSN != string.Empty) @@ -94,7 +94,7 @@ static void Main(string[] args) }; } - Task.Run(() => + _ = Task.Run(() => { // "ToLower" require "CultureInfo" to be initialized on first run, which can take a lot of time. _ = "a".ToLower(); @@ -208,7 +208,7 @@ static void Main(string[] args) // Wait for this issue to be fixed before enabling it: https://github.com/Consolonia/Consolonia/issues/602 // MaybePrintClassicModeTuiRecommendation(); - RunStartupSequence(args); + await RunStartupSequenceAsync(args); } /// @@ -381,7 +381,10 @@ internal static void HandleConfigLoadFailure() /// Called from Main() for classic/basic mode, or from TuiConsoleBackend on a /// background thread after the Avalonia UI loop has started. /// - internal static void RunStartupSequence(string[] args) + internal static void RunStartupSequence(string[] args) => + RunStartupSequenceAsync(args).GetAwaiter().GetResult(); + + internal static async Task RunStartupSequenceAsync(string[] args) { //Other command-line arguments if (args.Length >= 1) @@ -574,20 +577,20 @@ internal static void RunStartupSequence(string[] args) if (string.IsNullOrWhiteSpace(InternalConfig.Account.Password) && !skipPassword && (Config.Main.Advanced.SessionCache == CacheType.none || !SessionCache.Contains(ToLowerIfNeed(InternalConfig.Account.Login)))) { - RequestPassword(); + await RequestPasswordAsync(); } startupargs = args; - InitializeClient(); + await InitializeClientAsync(); } /// /// Reduest user to submit password. /// - private static void RequestPassword() + private static async Task RequestPasswordAsync() { ConsoleIO.WriteLine(ConsoleIO.BasicIO ? string.Format(Translations.mcc_password_basic_io, InternalConfig.Account.Login) + "\n" : Translations.mcc_password_hidden); - string? password = ConsoleIO.BasicIO ? Console.ReadLine() : ConsoleIO.ReadPassword(); + string? password = await ConsoleIO.ReadPasswordAsync(); if (string.IsNullOrWhiteSpace(password)) InternalConfig.Account.Password = "-"; else @@ -597,7 +600,7 @@ private static void RequestPassword() /// /// Start a new Client /// - private static void InitializeClient() + private static async Task InitializeClientAsync() { // Ensure that we use the provided Minecraft version if we can't connect automatically. // @@ -634,7 +637,9 @@ private static void InitializeClient() { try { - result = ProtocolHandler.MicrosoftLoginRefresh(session.RefreshToken, out session); + var refreshResult = await ProtocolHandler.MicrosoftLoginRefreshAsync(session.RefreshToken); + result = refreshResult.Result; + session = refreshResult.Session; } catch (Exception ex) { @@ -646,7 +651,7 @@ private static void InitializeClient() if (result != ProtocolHandler.LoginResult.Success && string.IsNullOrWhiteSpace(InternalConfig.Account.Password) && !(Config.Main.General.AccountType == LoginType.microsoft)) - RequestPassword(); + await RequestPasswordAsync(); } else ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.mcc_session_valid, session.PlayerName)); } @@ -654,14 +659,16 @@ private static void InitializeClient() if (result != ProtocolHandler.LoginResult.Success) { ConsoleIO.WriteLine(string.Format(Translations.mcc_connecting, Config.Main.General.AccountType == LoginType.mojang ? "Minecraft.net" : (Config.Main.General.AccountType == LoginType.microsoft ? "Microsoft" : Config.Main.General.AuthServer.Host))); - result = ProtocolHandler.GetLogin(InternalConfig.Account.Login, InternalConfig.Account.Password, Config.Main.General.AccountType, out session); + var loginResult = await ProtocolHandler.GetLoginAsync(InternalConfig.Account.Login, InternalConfig.Account.Password, Config.Main.General.AccountType); + result = loginResult.Result; + session = loginResult.Session; } if (result == ProtocolHandler.LoginResult.Success && Config.Main.Advanced.SessionCache != CacheType.none) SessionCache.Store(loginLower, session); if (result == ProtocolHandler.LoginResult.Success) - session.SessionPreCheckTask = Task.Factory.StartNew(() => session.SessionPreCheck(Config.Main.General.AccountType)); + session.SessionPreCheckTask = session.SessionPreCheckAsync(Config.Main.General.AccountType); } if (result == ProtocolHandler.LoginResult.Success) @@ -680,7 +687,7 @@ private static void InitializeClient() List availableWorlds = new(); if (Config.Main.Advanced.MinecraftRealms && !String.IsNullOrEmpty(session.ID)) - availableWorlds = ProtocolHandler.RealmsListWorlds(InternalConfig.Username, session.PlayerID, session.ID); + availableWorlds = await ProtocolHandler.RealmsListWorldsAsync(InternalConfig.Username, session.PlayerID, session.ID); if (InternalConfig.ServerIP == string.Empty) { @@ -700,7 +707,7 @@ private static void InitializeClient() worldId = availableWorlds[worldIndex]; if (availableWorlds.Contains(worldId)) { - string realmsAddress = ProtocolHandler.GetRealmsWorldServerAddress(worldId, InternalConfig.Username, session.PlayerID, session.ID); + string realmsAddress = await ProtocolHandler.GetRealmsWorldServerAddressAsync(worldId, InternalConfig.Username, session.PlayerID, session.ID); if (realmsAddress != "") { addressInput = realmsAddress; @@ -756,11 +763,15 @@ private static void InitializeClient() ConsoleIO.WriteLine(Translations.mcc_forge); else ConsoleIO.WriteLine(Translations.mcc_retrieve); - if (!ProtocolHandler.GetServerInfo(InternalConfig.ServerIP, InternalConfig.ServerPort, ref protocolversion, ref forgeInfo)) + var serverInfo = await ProtocolHandler.GetServerInfoAsync(InternalConfig.ServerIP, InternalConfig.ServerPort, protocolversion); + if (!serverInfo.Success) { HandleFailure(Translations.error_ping, true, ChatBot.DisconnectReason.ConnectionLost); return; } + + protocolversion = serverInfo.ProtocolVersion; + forgeInfo = serverInfo.ForgeInfo; } if ((Config.Main.General.AccountType == LoginType.microsoft || Config.Main.General.AccountType == LoginType.yggdrasil) @@ -890,30 +901,41 @@ public static void WriteBackSettings(bool enableBackup = true) /// /// Optional delay, in seconds, before restarting /// Optional, keep account and server settings - public static void Restart(int delaySeconds = 0, bool keepAccountAndServerSettings = false) + public static void Restart(int delaySeconds = 0, bool keepAccountAndServerSettings = false, bool announceDelay = true) { ConsoleIO.Backend?.StopReadThread(); - new Thread(new ThreadStart(delegate + StartLifecycleTask(RestartAsync(delaySeconds, keepAccountAndServerSettings, announceDelay)); + } + + private static async Task RestartAsync(int delaySeconds, bool keepAccountAndServerSettings, bool announceDelay) + { + if (client is not null) { client.Disconnect(); ConsoleIO.Reset(); } + if (offlinePrompt is not null) { - if (client is not null) { client.Disconnect(); ConsoleIO.Reset(); } - if (offlinePrompt is not null) - { - if (ConsoleIO.Backend is not null) - ConsoleIO.Backend.OnInputChange -= ConsoleIO.OfflineAutocompleteHandler; - offlinePrompt.Item2.Cancel(); offlinePrompt.Item1.Join(); offlinePrompt = null; ConsoleIO.Reset(); - } - if (delaySeconds > 0) - { + if (ConsoleIO.Backend is not null) + ConsoleIO.Backend.OnInputChange -= ConsoleIO.OfflineAutocompleteHandler; + offlinePrompt.Item2.Cancel(); + offlinePrompt.Item1.Join(); + offlinePrompt = null; + ConsoleIO.Reset(); + } + if (delaySeconds > 0) + { + if (announceDelay) ConsoleIO.WriteLine(string.Format(Translations.mcc_restart_delay, delaySeconds)); - Thread.Sleep(delaySeconds * 1000); - } - ConsoleIO.WriteLine(Translations.mcc_restart); - ReloadSettings(keepAccountAndServerSettings); - InitializeClient(); - })).Start(); + await Task.Delay(TimeSpan.FromSeconds(delaySeconds)); + } + ConsoleIO.WriteLine(Translations.mcc_restart); + ReloadSettings(keepAccountAndServerSettings); + await InitializeClientAsync(); } public static void DoExit(int exitcode = 0) + { + DoExitAsync(exitcode).GetAwaiter().GetResult(); + } + + private static Task DoExitAsync(int exitcode = 0) { WriteBackSettings(); ConsoleIO.WriteLineFormatted("§a" + string.Format(Translations.config_saving, settingsIniPath)); @@ -932,6 +954,7 @@ public static void DoExit(int exitcode = 0) if (Config.Main.Advanced.PlayerHeadAsIcon) { ConsoleIcon.RevertToMCCIcon(); } ConsoleIO.Backend?.Shutdown(); Environment.Exit(exitcode); + return Task.CompletedTask; } /// @@ -939,7 +962,26 @@ public static void DoExit(int exitcode = 0) /// public static void Exit(int exitcode = 0) { - new Thread(() => { DoExit(exitcode); }).Start(); + StartLifecycleTask(DoExitAsync(exitcode)); + } + + private static void StartLifecycleTask(Task lifecycleTask) + { + _ = ObserveLifecycleTaskAsync(lifecycleTask); + } + + private static async Task ObserveLifecycleTaskAsync(Task lifecycleTask) + { + try + { + await lifecycleTask; + } + catch (Exception ex) + { + SentrySdk.CaptureException(ex); + if (Settings.Config.Logging.DebugMessages) + ConsoleIO.WriteLineFormatted("§8" + ex); + } } /// diff --git a/MinecraftClient/Protocol/Handlers/DataTypes.cs b/MinecraftClient/Protocol/Handlers/DataTypes.cs index bdbdecea39..68a842e394 100644 --- a/MinecraftClient/Protocol/Handlers/DataTypes.cs +++ b/MinecraftClient/Protocol/Handlers/DataTypes.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Text; +using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Inventory; using MinecraftClient.Inventory.ItemPalettes; using MinecraftClient.Mapping; @@ -304,7 +306,28 @@ public int ReadNextVarIntRAW(SocketWrapper socket) byte b; while (true) { - b = socket.ReadDataRAW(1)[0]; + b = socket.ReadByteRAW(); + i |= (b & 0x7F) << j++ * 7; + if (j > 5) throw new OverflowException("VarInt too big"); + if ((b & 0x80) != 128) break; + } + + return i; + } + + /// + /// Read an integer from the network asynchronously. + /// + /// The integer + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public async Task ReadNextVarIntRAWAsync(SocketWrapper socket, CancellationToken cancellationToken) + { + int i = 0; + int j = 0; + byte b; + while (true) + { + b = await socket.ReadByteRAWAsync(cancellationToken); i |= (b & 0x7F) << j++ * 7; if (j > 5) throw new OverflowException("VarInt too big"); if ((b & 0x80) != 128) break; diff --git a/MinecraftClient/Protocol/Handlers/Protocol16.cs b/MinecraftClient/Protocol/Handlers/Protocol16.cs index 6777200d03..23741df0a3 100644 --- a/MinecraftClient/Protocol/Handlers/Protocol16.cs +++ b/MinecraftClient/Protocol/Handlers/Protocol16.cs @@ -7,6 +7,7 @@ using System.Security.Cryptography; using System.Text; using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Crypto; using MinecraftClient.Inventory; using MinecraftClient.Mapping; @@ -29,7 +30,9 @@ class Protocol16Handler : IMinecraftCom readonly IMinecraftComHandler handler; private bool encrypted = false; private readonly int protocolversion; - private Tuple? netRead = null; + private Task? netReadTask; + private CancellationTokenSource? netReadCancellationTokenSource; + private int netReadThreadId = -1; Crypto.AesCfb8Stream? s; readonly TcpClient c; @@ -69,15 +72,15 @@ private Protocol16Handler(TcpClient Client) c = Client; } - private void Updater(object? o) + private async Task UpdaterAsync(CancellationToken cancelToken) { - var cancelToken = (CancellationToken)o!; - if (cancelToken.IsCancellationRequested) return; try { + netReadThreadId = Environment.CurrentManagedThreadId; + using IDisposable _ = MainThreadExecutionScope.Enter(handler); Stopwatch stopWatch = Stopwatch.StartNew(); long nextUpdateDue = 0; @@ -97,13 +100,15 @@ private void Updater(object? o) long sleepLength = nextUpdateDue - stopWatch.ElapsedMilliseconds; if (sleepLength > 1) - Thread.Sleep((int)Math.Min(sleepLength, ClientTickIntervalMilliseconds)); + await Task.Delay((int)Math.Min(sleepLength, ClientTickIntervalMilliseconds), cancelToken); } } catch (System.IO.IOException) { } catch (SocketException) { } catch (ObjectDisposedException) { } catch (OperationCanceledException) { } + catch (Exception) { } + finally { netReadThreadId = -1; } if (cancelToken.IsCancellationRequested) return; @@ -240,9 +245,9 @@ private bool ProcessPacket(byte id) private void StartUpdating() { - netRead = new(new Thread(new ParameterizedThreadStart(Updater)), new CancellationTokenSource()); - netRead.Item1.Name = "ProtocolPacketHandler"; - netRead.Item1.Start(netRead.Item2.Token); + CancellationTokenSource netReadCts = new(); + netReadCancellationTokenSource = netReadCts; + netReadTask = Task.Run(() => UpdaterAsync(netReadCts.Token), netReadCts.Token); } /// @@ -251,7 +256,7 @@ private void StartUpdating() /// Net read thread ID public int GetNetMainThreadId() { - return netRead is not null ? netRead.Item1.ManagedThreadId : -1; + return netReadThreadId; } public bool SendCookieResponse(string name, byte[]? data) @@ -268,9 +273,9 @@ public void Dispose() { try { - if (netRead is not null) + if (netReadCancellationTokenSource is not null) { - netRead.Item2.Cancel(); + netReadCancellationTokenSource.Cancel(); c.Close(); } } @@ -519,7 +524,8 @@ private bool Handshake(string uuid, string username, string sessionID, string ho Receive(pid, 0, 1, SocketFlags.None); while (pid[0] == 0xFA) //Skip some early plugin messages { - ProcessPacket(pid[0]); + using (MainThreadExecutionScope.Enter(handler)) + ProcessPacket(pid[0]); Receive(pid, 0, 1, SocketFlags.None); } if (pid[0] == 0xFD) @@ -559,8 +565,7 @@ private bool StartEncryption(string uuid, string username, string sessionID, Log if (session.ServerPublicKey is not null && session.SessionPreCheckTask is not null && serverIDhash == session.ServerIDhash && Enumerable.SequenceEqual(serverPublicKey, session.ServerPublicKey)) { - session.SessionPreCheckTask.Wait(); - if (session.SessionPreCheckTask.Result) // PreCheck Successed + if (session.SessionPreCheckTask.IsCompletedSuccessfully && session.SessionPreCheckTask.Result) needCheckSession = false; } @@ -633,7 +638,8 @@ public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTra Receive(pid, 0, 1, SocketFlags.None); while (pid[0] >= 0xC0 && pid[0] != 0xFF) //Skip some early packets or plugin messages { - ProcessPacket(pid[0]); + using (MainThreadExecutionScope.Enter(handler)) + ProcessPacket(pid[0]); Receive(pid, 0, 1, SocketFlags.None); } if (pid[0] == (byte)1) diff --git a/MinecraftClient/Protocol/Handlers/Protocol18.cs b/MinecraftClient/Protocol/Handlers/Protocol18.cs index 21d37c887e..f554a35acb 100644 --- a/MinecraftClient/Protocol/Handlers/Protocol18.cs +++ b/MinecraftClient/Protocol/Handlers/Protocol18.cs @@ -9,6 +9,7 @@ using System.Text; using System.Text.RegularExpressions; using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Crypto; using MinecraftClient.Inventory; using MinecraftClient.Inventory.ItemPalettes; @@ -117,8 +118,11 @@ class Protocol18Handler : IMinecraftCom readonly PacketTypePalette packetPalette; readonly SocketWrapper socketWrapper; readonly DataTypes dataTypes; - Tuple? netMain = null; // main thread - Tuple? netReader = null; // reader thread + private Task? netMainTask; + private CancellationTokenSource? netMainCancellationTokenSource; + private int netMainThreadId = -1; + private Task? netReaderTask; + private CancellationTokenSource? netReaderCancellationTokenSource; readonly ILogger log; readonly RandomNumberGenerator randomGen; private bool legacyAchievementsInitialized; @@ -278,17 +282,17 @@ public Protocol18Handler(TcpClient Client, int protocolVersion, IMinecraftComHan } /// - /// Separate thread. Network reading loop. + /// Serialized packet/tick loop. /// - private void Updater(object? o) + private async Task UpdaterAsync(CancellationToken cancelToken) { - var cancelToken = (CancellationToken)o!; - if (cancelToken.IsCancellationRequested) return; try { + netMainThreadId = Environment.CurrentManagedThreadId; + using IDisposable _ = MainThreadExecutionScope.Enter(handler); Stopwatch stopWatch = Stopwatch.StartNew(); long nextUpdateDue = 0; while (!packetQueue.IsAddingCompleted) @@ -312,7 +316,7 @@ private void Updater(object? o) long sleepLength = nextUpdateDue - stopWatch.ElapsedMilliseconds; if (sleepLength > 1) - Thread.Sleep((int)Math.Min(sleepLength, ClientTickIntervalMilliseconds)); + await Task.Delay((int)Math.Min(sleepLength, ClientTickIntervalMilliseconds), cancelToken); } } catch (ObjectDisposedException) @@ -330,6 +334,13 @@ private void Updater(object? o) catch (System.IO.IOException) { } + catch (Exception) + { + } + finally + { + netMainThreadId = -1; + } if (cancelToken.IsCancellationRequested) return; @@ -340,20 +351,13 @@ private void Updater(object? o) /// /// Read and decompress packets. /// - internal void PacketReader(object? o) + internal async Task PacketReaderAsync(CancellationToken cancelToken) { - var cancelToken = (CancellationToken)o!; - while (socketWrapper.IsConnected() && !cancelToken.IsCancellationRequested) + while (!cancelToken.IsCancellationRequested) { try { - while (socketWrapper.HasDataAvailable()) - { - packetQueue.Add(ReadNextPacket(), cancelToken); - - if (cancelToken.IsCancellationRequested) - break; - } + packetQueue.Add(await ReadNextPacketAsync(cancelToken), cancelToken); } catch (OperationCanceledException) { @@ -375,11 +379,10 @@ internal void PacketReader(object? o) { break; } - - if (cancelToken.IsCancellationRequested) + catch (Exception) + { break; - - Thread.Sleep(10); + } } packetQueue.CompleteAdding(); @@ -392,23 +395,21 @@ internal void PacketReader(object? o) /// will contain raw packet Data internal Tuple> ReadNextPacket() { - var size = dataTypes.ReadNextVarIntRAW(socketWrapper); //Packet size - Queue packetData = new(socketWrapper.ReadDataRAW(size)); //Packet contents + var (packetId, packetData) = socketWrapper.GetNextPacket( + protocolVersion >= MC_1_8_Version ? compression_treshold : -1, + dataTypes); + if (handler.GetNetworkPacketCaptureEnabled()) + handler.OnNetworkPacket(packetId, packetData.ToList(), currentState == CurrentState.Login, true); - //Handle packet decompression - if (protocolVersion >= MC_1_8_Version - && compression_treshold >= 0) - { - var sizeUncompressed = dataTypes.ReadNextVarInt(packetData); - if (sizeUncompressed != 0) // != 0 means compressed, let's decompress - { - var toDecompress = packetData.ToArray(); - var uncompressed = ZlibUtils.Decompress(toDecompress, sizeUncompressed); - packetData = new Queue(uncompressed); - } - } + return new(packetId, packetData); + } - var packetId = dataTypes.ReadNextVarInt(packetData); // Packet ID + internal async Task>> ReadNextPacketAsync(CancellationToken cancellationToken) + { + var (packetId, packetData) = await socketWrapper.GetNextPacketAsync( + protocolVersion >= MC_1_8_Version ? compression_treshold : -1, + dataTypes, + cancellationToken); if (handler.GetNetworkPacketCaptureEnabled()) handler.OnNetworkPacket(packetId, packetData.ToList(), currentState == CurrentState.Login, true); @@ -3840,23 +3841,17 @@ private bool SkipRecipeBookSettings(Queue packetData) } /// - /// Start the updating thread. Should be called after login success. + /// Start the serialized packet/tick tasks. Should be called after login success. /// private void StartUpdating() { - Thread threadUpdater = new(new ParameterizedThreadStart(Updater)) - { - Name = "ProtocolPacketHandler" - }; - netMain = new Tuple(threadUpdater, new CancellationTokenSource()); - threadUpdater.Start(netMain.Item2.Token); + CancellationTokenSource netMainCts = new(); + netMainCancellationTokenSource = netMainCts; + netMainTask = Task.Run(() => UpdaterAsync(netMainCts.Token), netMainCts.Token); - Thread threadReader = new(new ParameterizedThreadStart(PacketReader)) - { - Name = "ProtocolPacketReader" - }; - netReader = new Tuple(threadReader, new CancellationTokenSource()); - threadReader.Start(netReader.Item2.Token); + CancellationTokenSource netReaderCts = new(); + netReaderCancellationTokenSource = netReaderCts; + netReaderTask = Task.Run(() => PacketReaderAsync(netReaderCts.Token), netReaderCts.Token); } /// @@ -3865,7 +3860,7 @@ private void StartUpdating() /// Net read thread ID public int GetNetMainThreadId() { - return netMain is not null ? netMain.Item1.ManagedThreadId : -1; + return netMainThreadId; } /// @@ -3875,14 +3870,14 @@ public void Dispose() { try { - if (netMain is not null) + if (netMainCancellationTokenSource is not null) { - netMain.Item2.Cancel(); + netMainCancellationTokenSource.Cancel(); } - if (netReader is not null) + if (netReaderCancellationTokenSource is not null) { - netReader.Item2.Cancel(); + netReaderCancellationTokenSource.Cancel(); socketWrapper.Disconnect(); } } @@ -4106,7 +4101,8 @@ public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTra return true; //No need to check session or start encryption } default: - HandlePacket(packetId, packetData); + using (MainThreadExecutionScope.Enter(handler)) + HandlePacket(packetId, packetData); break; } } @@ -4133,8 +4129,7 @@ private bool StartEncryption(string uuid, string sessionID, LoginType type, byte && serverIDhash == session.ServerIDhash && serverPublicKey.SequenceEqual(session.ServerPublicKey)) { - session.SessionPreCheckTask.Wait(); - if (session.SessionPreCheckTask.Result) // PreCheck Success + if (session.SessionPreCheckTask.IsCompletedSuccessfully && session.SessionPreCheckTask.Result) needCheckSession = false; } @@ -4256,7 +4251,8 @@ private bool StartEncryption(string uuid, string sessionID, LoginType type, byte return true; } default: - HandlePacket(packetId, packetData); + using (MainThreadExecutionScope.Enter(handler)) + HandlePacket(packetId, packetData); break; } } @@ -4355,14 +4351,8 @@ public static bool DoPing(string host, int port, ref int protocolVersion, ref Fo var statusRequest = DataTypes.GetVarInt(0); socketWrapper.SendDataRAW(dataTypes.ConcatBytes(DataTypes.GetVarInt(statusRequest.Length), statusRequest)); - // Read Response length - var packetLength = dataTypes.ReadNextVarIntRAW(socketWrapper); - if (packetLength <= 0) - return false; - - // Read the Packet Id - var packetData = new Queue(socketWrapper.ReadDataRAW(packetLength)); - if (dataTypes.ReadNextVarInt(packetData) != 0x00) + var (statusPacketId, packetData) = socketWrapper.GetNextPacket(-1, dataTypes); + if (statusPacketId != 0x00) return false; // Get the Json data @@ -4441,15 +4431,11 @@ public static bool DoPing(string host, int port, ref int protocolVersion, ref Fo var pingRequest = dataTypes.ConcatBytes(DataTypes.GetVarInt(0x01), DataTypes.GetLong(pingPayload)); socketWrapper.SendDataRAW(dataTypes.ConcatBytes(DataTypes.GetVarInt(pingRequest.Length), pingRequest)); - packetLength = dataTypes.ReadNextVarIntRAW(socketWrapper); - if (packetLength > 0) + var (pongPacketId, pongPacketData) = socketWrapper.GetNextPacket(-1, dataTypes); + if (pongPacketId == 0x01) { - packetData = new Queue(socketWrapper.ReadDataRAW(packetLength)); - if (dataTypes.ReadNextVarInt(packetData) == 0x01) - { - long pongPayload = dataTypes.ReadNextLong(packetData); - pingMs = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() - pingPayload; - } + long pongPayload = dataTypes.ReadNextLong(pongPacketData); + pingMs = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() - pingPayload; } } catch diff --git a/MinecraftClient/Protocol/Handlers/SocketWrapper.cs b/MinecraftClient/Protocol/Handlers/SocketWrapper.cs index a4f451b134..c3ed8b1531 100644 --- a/MinecraftClient/Protocol/Handlers/SocketWrapper.cs +++ b/MinecraftClient/Protocol/Handlers/SocketWrapper.cs @@ -1,6 +1,12 @@ using System; +using System.Collections.Generic; +using System.IO; +using System.IO.Compression; using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Crypto; +using MinecraftClient.Protocol.PacketPipeline; namespace MinecraftClient.Protocol.Handlers { @@ -9,9 +15,14 @@ namespace MinecraftClient.Protocol.Handlers /// public class SocketWrapper { - readonly TcpClient c; - AesCfb8Stream? s; - bool encrypted = false; + private readonly TcpClient client; + private readonly Stream networkStream; + private readonly SemaphoreSlim sendSemaphore = new(1, 1); + private readonly byte[] singleByteBuffer = new byte[1]; + private AesCfb8Stream? encryptedStream; + private Stream readStream; + private Stream writeStream; + private bool encrypted = false; /// /// Initialize a new SocketWrapper @@ -19,7 +30,9 @@ public class SocketWrapper /// TcpClient connected to the server public SocketWrapper(TcpClient client) { - c = client; + this.client = client; + networkStream = client.GetStream(); + readStream = writeStream = networkStream; } /// @@ -29,7 +42,7 @@ public SocketWrapper(TcpClient client) /// Silently dropped connection can only be detected by attempting to read/write data public bool IsConnected() { - return c.Client is not null && c.Connected; + return client.Client is not null && client.Connected; } /// @@ -38,7 +51,7 @@ public bool IsConnected() /// TRUE if data is available to read public bool HasDataAvailable() { - return c.Client.Available > 0; + return client.Client.Available > 0; } /// @@ -49,23 +62,21 @@ public void SwitchToEncrypted(byte[] secretKey) { if (encrypted) throw new InvalidOperationException("Stream is already encrypted!?"); - s = new AesCfb8Stream(c.GetStream(), secretKey); + encryptedStream = new AesCfb8Stream(networkStream, secretKey); + readStream = writeStream = encryptedStream; encrypted = true; } - /// - /// Network reading method. Read bytes from the socket or encrypted socket. - /// - private void Receive(byte[] buffer, int start, int offset, SocketFlags f) + public byte ReadByteRAW() { - int read = 0; - while (read < offset) - { - if (encrypted) - read += s!.Read(buffer, start + read, offset - read); - else - read += c.Client.Receive(buffer, start + read, offset - read, f); - } + readStream.ReadExactly(singleByteBuffer); + return singleByteBuffer[0]; + } + + public async ValueTask ReadByteRAWAsync(CancellationToken cancellationToken) + { + await readStream.ReadExactlyAsync(singleByteBuffer.AsMemory(0, 1), cancellationToken); + return singleByteBuffer[0]; } /// @@ -77,13 +88,45 @@ public byte[] ReadDataRAW(int length) { if (length > 0) { - byte[] cache = new byte[length]; - Receive(cache, 0, length, SocketFlags.None); + byte[] cache = GC.AllocateUninitializedArray(length); + readStream.ReadExactly(cache); + return cache; + } + return Array.Empty(); + } + + public async Task ReadDataRAWAsync(int length, CancellationToken cancellationToken) + { + if (length > 0) + { + byte[] cache = GC.AllocateUninitializedArray(length); + await readStream.ReadExactlyAsync(cache.AsMemory(0, length), cancellationToken); return cache; } + return Array.Empty(); } + internal Tuple> GetNextPacket(int compressionThreshold, DataTypes dataTypes) + { + int packetLength = ReadNextVarIntRaw(); + using PacketReadStream packetStream = new(readStream, packetLength); + byte[] payload = ReadPacketPayload(packetStream, compressionThreshold); + Queue packetData = new(payload); + int packetId = dataTypes.ReadNextVarInt(packetData); + return new(packetId, packetData); + } + + internal async Task>> GetNextPacketAsync(int compressionThreshold, DataTypes dataTypes, CancellationToken cancellationToken) + { + int packetLength = await ReadNextVarIntRawAsync(cancellationToken); + await using PacketReadStream packetStream = new(readStream, packetLength); + byte[] payload = await ReadPacketPayloadAsync(packetStream, compressionThreshold, cancellationToken); + Queue packetData = new(payload); + int packetId = dataTypes.ReadNextVarInt(packetData); + return new(packetId, packetData); + } + /// /// Send raw data to the server. /// @@ -93,10 +136,33 @@ public void SendDataRAW(byte[] buffer) if (!IsConnected()) throw new SocketException((int)SocketError.NotConnected); - if (encrypted) - s!.Write(buffer, 0, buffer.Length); - else - c.Client.Send(buffer); + sendSemaphore.Wait(); + try + { + writeStream.Write(buffer, 0, buffer.Length); + writeStream.Flush(); + } + finally + { + sendSemaphore.Release(); + } + } + + public async Task SendDataRAWAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + if (!IsConnected()) + throw new SocketException((int)SocketError.NotConnected); + + await sendSemaphore.WaitAsync(cancellationToken); + try + { + await writeStream.WriteAsync(buffer, cancellationToken); + await writeStream.FlushAsync(cancellationToken); + } + finally + { + sendSemaphore.Release(); + } } /// @@ -106,12 +172,117 @@ public void Disconnect() { try { - c.Close(); + encryptedStream?.Dispose(); + client.Close(); } catch (SocketException) { } catch (System.IO.IOException) { } catch (NullReferenceException) { } catch (ObjectDisposedException) { } } + + private int ReadNextVarIntRaw() + { + int value = 0; + int position = 0; + + while (true) + { + byte current = ReadByteRAW(); + value |= (current & 0x7F) << position++ * 7; + if (position > 5) + throw new OverflowException("VarInt too big"); + if ((current & 0x80) != 0x80) + return value; + } + } + + private async Task ReadNextVarIntRawAsync(CancellationToken cancellationToken) + { + int value = 0; + int position = 0; + + while (true) + { + byte current = await ReadByteRAWAsync(cancellationToken); + value |= (current & 0x7F) << position++ * 7; + if (position > 5) + throw new OverflowException("VarInt too big"); + if ((current & 0x80) != 0x80) + return value; + } + } + + private static byte[] ReadPacketPayload(PacketReadStream packetStream, int compressionThreshold) + { + if (compressionThreshold >= 0) + { + int uncompressedLength = ReadNextVarIntRaw(packetStream); + if (uncompressedLength > 0) + { + using ZLibStream zlibStream = new(packetStream, CompressionMode.Decompress, leaveOpen: true); + byte[] payload = GC.AllocateUninitializedArray(uncompressedLength); + zlibStream.ReadExactly(payload); + return payload; + } + } + + return packetStream.ReadRemaining(); + } + + private static async Task ReadPacketPayloadAsync(PacketReadStream packetStream, int compressionThreshold, CancellationToken cancellationToken) + { + if (compressionThreshold >= 0) + { + int uncompressedLength = await ReadNextVarIntRawAsync(packetStream, cancellationToken); + if (uncompressedLength > 0) + { + await using ZLibStream zlibStream = new(packetStream, CompressionMode.Decompress, leaveOpen: true); + byte[] payload = GC.AllocateUninitializedArray(uncompressedLength); + await zlibStream.ReadExactlyAsync(payload.AsMemory(0, uncompressedLength), cancellationToken); + return payload; + } + } + + return await packetStream.ReadRemainingAsync(cancellationToken); + } + + private static int ReadNextVarIntRaw(Stream stream) + { + int value = 0; + int position = 0; + + while (true) + { + int current = stream.ReadByte(); + if (current < 0) + throw new IOException("Connection closed."); + + value |= (current & 0x7F) << position++ * 7; + if (position > 5) + throw new OverflowException("VarInt too big"); + if ((current & 0x80) != 0x80) + return value; + } + } + + private static async Task ReadNextVarIntRawAsync(Stream stream, CancellationToken cancellationToken) + { + byte[] buffer = new byte[1]; + int value = 0; + int position = 0; + + while (true) + { + await stream.ReadExactlyAsync(buffer.AsMemory(0, 1), cancellationToken); + byte current = buffer[0]; + + value |= (current & 0x7F) << position++ * 7; + if (position > 5) + throw new OverflowException("VarInt too big"); + if ((current & 0x80) != 0x80) + return value; + } + } } } diff --git a/MinecraftClient/Protocol/Message/ChatParser.cs b/MinecraftClient/Protocol/Message/ChatParser.cs index e0954ee2ca..b0dcbc8fd8 100644 --- a/MinecraftClient/Protocol/Message/ChatParser.cs +++ b/MinecraftClient/Protocol/Message/ChatParser.cs @@ -7,6 +7,7 @@ using System.Text; using System.Text.Json; using System.Text.RegularExpressions; +using System.Threading; using System.Threading.Tasks; using static MinecraftClient.Settings; @@ -231,6 +232,8 @@ private static string Color2tag(string colorname) /// Specify whether translation rules have been loaded /// private static bool RulesInitialized = false; + private static readonly Lock RulesInitializationLock = new(); + private static Task? RulesRefreshTask = null; /// /// Set of translation rules for formatting text @@ -243,23 +246,25 @@ private static string Color2tag(string colorname) /// public static void InitTranslations() { - if (!RulesInitialized) + lock (RulesInitializationLock) { - InitRules(); + if (RulesInitialized) + return; + RulesInitialized = true; + RulesRefreshTask = InitRulesAsync(); + _ = ObserveInitRulesAsync(RulesRefreshTask); } } /// - /// Internal rule initialization method. Looks for local rule file or download it from Mojang asset servers. + /// Internal rule initialization method. Looks for local rule file and refreshes it from Mojang asset servers if needed. /// - private static void InitRules() + private static async Task InitRulesAsync() { if (Config.Main.Advanced.Language == "en_us") { - TranslationRules = - JsonSerializer.Deserialize>( - (byte[])MinecraftAssets.ResourceManager.GetObject("en_us.json")!)!; + TranslationRules = LoadEmbeddedTranslationRules(); return; } @@ -269,21 +274,9 @@ private static void InitRules() string languageFilePath = "lang" + Path.DirectorySeparatorChar + Config.Main.Advanced.Language + ".json"; - // Load the external dictionary of translation rules or display an error message - if (File.Exists(languageFilePath)) - { - try - { - TranslationRules = - JsonSerializer.Deserialize>(File.OpenRead(languageFilePath))!; - } - catch (IOException) - { - } - catch (JsonException) - { - } - } + if (TryLoadTranslationRulesFromFile(languageFilePath, out Dictionary? translationRules)) + TranslationRules = translationRules; + else TranslationRules = LoadEmbeddedTranslationRules(); if (TranslationRules.TryGetValue("Version", out string? version) && version == Settings.TranslationsFile_Version) @@ -296,14 +289,12 @@ private static void InitRules() // Try downloading language file from Mojang's servers? ConsoleIO.WriteLineFormatted( "§8" + string.Format(Translations.chat_download, Config.Main.Advanced.Language)); - HttpClient httpClient = new(); + using HttpClient httpClient = new(); try { - Task fetch_index = httpClient.GetStringAsync(TranslationsFile_Website_Index); - fetch_index.Wait(); - Match match = Regex.Match(fetch_index.Result, + string fetchIndex = await httpClient.GetStringAsync(TranslationsFile_Website_Index); + Match match = Regex.Match(fetchIndex, $"minecraft/lang/{Config.Main.Advanced.Language}.json" + @""":\s\{""hash"":\s""([\d\w]{40})"""); - fetch_index.Dispose(); if (match.Success && match.Groups.Count == 2) { string hash = match.Groups[1].Value; @@ -312,22 +303,19 @@ private static void InitRules() ConsoleIO.WriteLineFormatted( string.Format(Translations.chat_request, translation_file_location)); - Task?> fetckFileTask = - httpClient.GetFromJsonAsync>(translation_file_location); - fetckFileTask.Wait(); - if (fetckFileTask.Result is not null && fetckFileTask.Result.Count > 0) + Dictionary? fetchedFile = + await httpClient.GetFromJsonAsync>(translation_file_location); + if (fetchedFile is not null && fetchedFile.Count > 0) { - TranslationRules = fetckFileTask.Result; + TranslationRules = fetchedFile; TranslationRules["Version"] = TranslationsFile_Version; - File.WriteAllText(languageFilePath, + await File.WriteAllTextAsync(languageFilePath, JsonSerializer.Serialize(TranslationRules, typeof(Dictionary)), Encoding.UTF8); ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.chat_done, languageFilePath)); return; } - - fetckFileTask.Dispose(); } else { @@ -350,15 +338,50 @@ private static void InitRules() if (Config.Logging.DebugMessages && !string.IsNullOrEmpty(e.StackTrace)) ConsoleIO.WriteLine(e.StackTrace); } - finally + TranslationRules = LoadEmbeddedTranslationRules(); + ConsoleIO.WriteLine(Translations.chat_use_default); + } + + private static async Task ObserveInitRulesAsync(Task initRulesTask) + { + try + { + await initRulesTask; + } + catch (Exception e) { - httpClient.Dispose(); + TranslationRules = LoadEmbeddedTranslationRules(); + if (Config.Logging.DebugMessages) + ConsoleIO.WriteLine(e.ToString()); } + } - TranslationRules = - JsonSerializer.Deserialize>( - (byte[])MinecraftAssets.ResourceManager.GetObject("en_us.json")!)!; - ConsoleIO.WriteLine(Translations.chat_use_default); + private static Dictionary LoadEmbeddedTranslationRules() + { + return JsonSerializer.Deserialize>( + (byte[])MinecraftAssets.ResourceManager.GetObject("en_us.json")!)!; + } + + private static bool TryLoadTranslationRulesFromFile(string languageFilePath, out Dictionary? translationRules) + { + translationRules = null; + if (!File.Exists(languageFilePath)) + return false; + + try + { + translationRules = + JsonSerializer.Deserialize>(File.OpenRead(languageFilePath))!; + return translationRules is not null; + } + catch (IOException) + { + return false; + } + catch (JsonException) + { + return false; + } } public static string? TranslateString(string rulename) @@ -379,10 +402,7 @@ private static void InitRules() private static string TranslateString(string rulename, List using_data) { if (!RulesInitialized) - { - InitRules(); - RulesInitialized = true; - } + InitTranslations(); if (TranslationRules.ContainsKey(rulename)) { @@ -617,4 +637,4 @@ private static string NbtToString(Dictionary nbt, string formatt return formatting + message + extraBuilder.ToString(); } } -} \ No newline at end of file +} diff --git a/MinecraftClient/Protocol/MicrosoftAuthentication.cs b/MinecraftClient/Protocol/MicrosoftAuthentication.cs index 4c84ce477e..1389fd8870 100644 --- a/MinecraftClient/Protocol/MicrosoftAuthentication.cs +++ b/MinecraftClient/Protocol/MicrosoftAuthentication.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Runtime.InteropServices; using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient.Protocol { @@ -37,7 +38,13 @@ public static LoginResponse RequestAccessToken(string code) { string postData = "client_id={0}&grant_type=authorization_code&redirect_uri=https%3A%2F%2Fmccteam.github.io%2Fredirect.html&code={1}"; postData = string.Format(postData, clientId, code); - return RequestToken(postData); + return RequestTokenAsync(postData).GetAwaiter().GetResult(); + } + + public static Task RequestAccessTokenAsync(string code) + { + string postData = "client_id={0}&grant_type=authorization_code&redirect_uri=https%3A%2F%2Fmccteam.github.io%2Fredirect.html&code={1}"; + return RequestTokenAsync(string.Format(postData, clientId, code)); } /// @@ -49,7 +56,13 @@ public static LoginResponse RefreshAccessToken(string refreshToken) { string postData = "client_id={0}&grant_type=refresh_token&redirect_uri=https%3A%2F%2Fmccteam.github.io%2Fredirect.html&refresh_token={1}"; postData = string.Format(postData, clientId, refreshToken); - return RequestToken(postData); + return RequestTokenAsync(postData).GetAwaiter().GetResult(); + } + + public static Task RefreshAccessTokenAsync(string refreshToken) + { + string postData = "client_id={0}&grant_type=refresh_token&redirect_uri=https%3A%2F%2Fmccteam.github.io%2Fredirect.html&refresh_token={1}"; + return RequestTokenAsync(string.Format(postData, clientId, refreshToken)); } /// @@ -58,6 +71,11 @@ public static LoginResponse RefreshAccessToken(string refreshToken) /// /// Device code response for user to complete authentication public static DeviceCodeResponse RequestDeviceCode() + { + return RequestDeviceCodeAsync().GetAwaiter().GetResult(); + } + + public static async Task RequestDeviceCodeAsync(CancellationToken cancellationToken = default) { string postData = string.Format("client_id={0}&scope=XboxLive.signin%20offline_access%20openid%20email", clientId); @@ -65,7 +83,7 @@ public static DeviceCodeResponse RequestDeviceCode() { UserAgent = "MCC/" + Program.Version }; - var response = request.Post("application/x-www-form-urlencoded", postData); + var response = await request.PostAsync("application/x-www-form-urlencoded", postData, cancellationToken); var jsonData = Json.ParseJson(response.Body); if (jsonData?["error"] is not null) @@ -93,6 +111,11 @@ public static DeviceCodeResponse RequestDeviceCode() /// Polling interval in seconds /// Login response with access token and refresh token public static LoginResponse PollDeviceCodeToken(string deviceCode, int expiresIn, int interval) + { + return PollDeviceCodeTokenAsync(deviceCode, expiresIn, interval).GetAwaiter().GetResult(); + } + + public static async Task PollDeviceCodeTokenAsync(string deviceCode, int expiresIn, int interval, CancellationToken cancellationToken = default) { // Per OAuth 2.0 device code spec, server may respond with "slow_down" requiring // the client to increase its polling interval by this amount @@ -107,13 +130,13 @@ public static LoginResponse PollDeviceCodeToken(string deviceCode, int expiresIn while (stopwatch.Elapsed.TotalSeconds < expiresIn) { - Thread.Sleep(pollInterval * 1000); + await Task.Delay(TimeSpan.FromSeconds(pollInterval), cancellationToken); var request = new ProxiedWebRequest(tokenUrl) { UserAgent = "MCC/" + Program.Version }; - var response = request.Post("application/x-www-form-urlencoded", postData); + var response = await request.PostAsync("application/x-www-form-urlencoded", postData, cancellationToken); var jsonData = Json.ParseJson(response.Body); if (jsonData?["error"] is not null) @@ -173,12 +196,17 @@ public static LoginResponse PollDeviceCodeToken(string deviceCode, int expiresIn /// Complete POST data for the request /// private static LoginResponse RequestToken(string postData) + { + return RequestTokenAsync(postData).GetAwaiter().GetResult(); + } + + private static async Task RequestTokenAsync(string postData, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(tokenUrl) { UserAgent = "MCC/" + Program.Version }; - var response = request.Post("application/x-www-form-urlencoded", postData); + var response = await request.PostAsync("application/x-www-form-urlencoded", postData, cancellationToken); var jsonData = Json.ParseJson(response.Body); // Error handling @@ -271,6 +299,11 @@ static class XboxLive /// /// public static XblAuthenticateResponse XblAuthenticate(Microsoft.LoginResponse loginResponse) + { + return XblAuthenticateAsync(loginResponse).GetAwaiter().GetResult(); + } + + public static async Task XblAuthenticateAsync(Microsoft.LoginResponse loginResponse, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(xbl) { @@ -291,7 +324,7 @@ public static XblAuthenticateResponse XblAuthenticate(Microsoft.LoginResponse lo + "\"RelyingParty\": \"http://auth.xboxlive.com\"," + "\"TokenType\": \"JWT\"" + "}"; - var response = request.Post("application/json", payload); + var response = await request.PostAsync("application/json", payload, cancellationToken); if (Settings.Config.Logging.DebugMessages) { ConsoleIO.WriteLine(response.ToString()); @@ -321,6 +354,11 @@ public static XblAuthenticateResponse XblAuthenticate(Microsoft.LoginResponse lo /// /// public static XSTSAuthenticateResponse XSTSAuthenticate(XblAuthenticateResponse xblResponse) + { + return XSTSAuthenticateAsync(xblResponse).GetAwaiter().GetResult(); + } + + public static async Task XSTSAuthenticateAsync(XblAuthenticateResponse xblResponse, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(xsts) { @@ -339,7 +377,7 @@ public static XSTSAuthenticateResponse XSTSAuthenticate(XblAuthenticateResponse + "\"RelyingParty\": \"rp://api.minecraftservices.com/\"," + "\"TokenType\": \"JWT\"" + "}"; - var response = request.Post("application/json", payload); + var response = await request.PostAsync("application/json", payload, cancellationToken); if (Settings.Config.Logging.DebugMessages) { ConsoleIO.WriteLine(response.ToString()); @@ -404,6 +442,11 @@ static class MinecraftWithXbox /// /// public static string LoginWithXbox(string userHash, string xstsToken) + { + return LoginWithXboxAsync(userHash, xstsToken).GetAwaiter().GetResult(); + } + + public static async Task LoginWithXboxAsync(string userHash, string xstsToken, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(loginWithXbox) { @@ -411,7 +454,7 @@ public static string LoginWithXbox(string userHash, string xstsToken) }; string payload = "{\"identityToken\": \"XBL3.0 x=" + userHash + ";" + xstsToken + "\"}"; - var response = request.Post("application/json", payload); + var response = await request.PostAsync("application/json", payload, cancellationToken); if (Settings.Config.Logging.DebugMessages) { @@ -430,10 +473,15 @@ public static string LoginWithXbox(string userHash, string xstsToken) /// /// True if the user own the game public static bool UserHasGame(string accessToken) + { + return UserHasGameAsync(accessToken).GetAwaiter().GetResult(); + } + + public static async Task UserHasGameAsync(string accessToken, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(ownership); request.Headers.Add("Authorization", string.Format("Bearer {0}", accessToken)); - var response = request.Get(); + var response = await request.GetAsync(cancellationToken); if (Settings.Config.Logging.DebugMessages) { @@ -446,10 +494,15 @@ public static bool UserHasGame(string accessToken) } public static UserProfile GetUserProfile(string accessToken) + { + return GetUserProfileAsync(accessToken).GetAwaiter().GetResult(); + } + + public static async Task GetUserProfileAsync(string accessToken, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(profile); request.Headers.Add("Authorization", string.Format("Bearer {0}", accessToken)); - var response = request.Get(); + var response = await request.GetAsync(cancellationToken); if (Settings.Config.Logging.DebugMessages) { diff --git a/MinecraftClient/Protocol/MojangAPI.cs b/MinecraftClient/Protocol/MojangAPI.cs index b05e6a7144..9e91c9ddbc 100644 --- a/MinecraftClient/Protocol/MojangAPI.cs +++ b/MinecraftClient/Protocol/MojangAPI.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Net.Http; +using System.Threading; using System.Threading.Tasks; /// !!! ATTENTION !!! @@ -103,14 +104,15 @@ private static ServiceStatus StringToServiceStatus(string s) /// /// Playername /// UUID as string - public static string NameToUuid(string name) + public static string NameToUuid(string name) => + NameToUuidAsync(name).GetAwaiter().GetResult(); + + public static async Task NameToUuidAsync(string name, CancellationToken cancellationToken = default) { try { - Task fetchTask = httpClient.GetStringAsync("https://api.mojang.com/users/profiles/minecraft/" + name); - fetchTask.Wait(); - string result = Json.ParseJson(fetchTask.Result)!["id"]!.GetStringValue(); - fetchTask.Dispose(); + string responseBody = await httpClient.GetStringAsync("https://api.mojang.com/users/profiles/minecraft/" + name, cancellationToken); + string result = Json.ParseJson(responseBody)!["id"]!.GetStringValue(); return result; } catch (Exception) { return string.Empty; } @@ -121,15 +123,16 @@ public static string NameToUuid(string name) /// /// UUID of a player /// Players UUID - public static string UuidToCurrentName(string uuid) + public static string UuidToCurrentName(string uuid) => + UuidToCurrentNameAsync(uuid).GetAwaiter().GetResult(); + + public static async Task UuidToCurrentNameAsync(string uuid, CancellationToken cancellationToken = default) { // Perform web request try { - Task fetchTask = httpClient.GetStringAsync("https://api.mojang.com/user/profiles/" + uuid + "/names"); - fetchTask.Wait(); - var nameChanges = Json.ParseJson(fetchTask.Result)!.AsArray(); - fetchTask.Dispose(); + string responseBody = await httpClient.GetStringAsync("https://api.mojang.com/user/profiles/" + uuid + "/names", cancellationToken); + var nameChanges = Json.ParseJson(responseBody)!.AsArray(); // Names are sorted from past to most recent. We need to get the last name in the list return nameChanges[^1]!["name"]!.GetStringValue(); @@ -142,7 +145,10 @@ public static string UuidToCurrentName(string uuid) /// /// UUID of a player /// Name history, as a dictionary - public static Dictionary UuidToNameHistory(string uuid) + public static Dictionary UuidToNameHistory(string uuid) => + UuidToNameHistoryAsync(uuid).GetAwaiter().GetResult(); + + public static async Task> UuidToNameHistoryAsync(string uuid, CancellationToken cancellationToken = default) { Dictionary tempDict = new(); System.Text.Json.Nodes.JsonArray jsonDataList; @@ -150,10 +156,8 @@ public static Dictionary UuidToNameHistory(string uuid) // Perform web request try { - Task fetchTask = httpClient.GetStringAsync("https://api.mojang.com/user/profiles/" + uuid + "/names"); - fetchTask.Wait(); - jsonDataList = Json.ParseJson(fetchTask.Result)!.AsArray(); - fetchTask.Dispose(); + string responseBody = await httpClient.GetStringAsync("https://api.mojang.com/user/profiles/" + uuid + "/names", cancellationToken); + jsonDataList = Json.ParseJson(responseBody)!.AsArray(); } catch (Exception) { return tempDict; } @@ -181,17 +185,18 @@ public static Dictionary UuidToNameHistory(string uuid) /// Get the Mojang API status /// /// Dictionary of the Mojang services - public static MojangServiceStatus GetMojangServiceStatus() + public static MojangServiceStatus GetMojangServiceStatus() => + GetMojangServiceStatusAsync().GetAwaiter().GetResult(); + + public static async Task GetMojangServiceStatusAsync(CancellationToken cancellationToken = default) { System.Text.Json.Nodes.JsonArray jsonDataList; // Perform web request try { - Task fetchTask = httpClient.GetStringAsync("https://status.mojang.com/check"); - fetchTask.Wait(); - jsonDataList = Json.ParseJson(fetchTask.Result)!.AsArray(); - fetchTask.Dispose(); + string responseBody = await httpClient.GetStringAsync("https://status.mojang.com/check", cancellationToken); + jsonDataList = Json.ParseJson(responseBody)!.AsArray(); } catch (Exception) { @@ -215,7 +220,10 @@ public static MojangServiceStatus GetMojangServiceStatus() /// /// UUID of a player /// Dictionary with a link to the skin and cape of a player. - public static SkinInfo GetSkinInfo(string uuid) + public static SkinInfo GetSkinInfo(string uuid) => + GetSkinInfoAsync(uuid).GetAwaiter().GetResult(); + + public static async Task GetSkinInfoAsync(string uuid, CancellationToken cancellationToken = default) { System.Text.Json.Nodes.JsonObject textureObj; string base64SkinInfo; @@ -224,11 +232,9 @@ public static SkinInfo GetSkinInfo(string uuid) // Perform web request try { - Task fetchTask = httpClient.GetStringAsync("https://sessionserver.mojang.com/session/minecraft/profile/" + uuid); - fetchTask.Wait(); + string responseBody = await httpClient.GetStringAsync("https://sessionserver.mojang.com/session/minecraft/profile/" + uuid, cancellationToken); // Obtain the Base64 encoded skin information from the API. Discard the rest, since it can be obtained easier through other requests. - base64SkinInfo = Json.ParseJson(fetchTask.Result)!["properties"]![0]!["value"]!.GetStringValue(); - fetchTask.Dispose(); + base64SkinInfo = Json.ParseJson(responseBody)!["properties"]![0]!["value"]!.GetStringValue(); } catch (Exception) { return new SkinInfo(); } diff --git a/MinecraftClient/Protocol/PacketPipeline/PacketReadStream.cs b/MinecraftClient/Protocol/PacketPipeline/PacketReadStream.cs new file mode 100644 index 0000000000..167eb03b77 --- /dev/null +++ b/MinecraftClient/Protocol/PacketPipeline/PacketReadStream.cs @@ -0,0 +1,203 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace MinecraftClient.Protocol.PacketPipeline; + +internal sealed class PacketReadStream : Stream +{ + private const int DrainBufferSize = 4096; + + private readonly Stream baseStream; + private readonly byte[] singleByteBuffer = new byte[1]; + private int remainingLength; + + public PacketReadStream(Stream baseStream, int packetLength) + { + ArgumentNullException.ThrowIfNull(baseStream); + if (packetLength < 0) + throw new ArgumentOutOfRangeException(nameof(packetLength)); + + this.baseStream = baseStream; + remainingLength = packetLength; + } + + public int RemainingLength => remainingLength; + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (remainingLength == 0) + return 0; + + int readLength = Math.Min(count, remainingLength); + int read = baseStream.Read(buffer, offset, readLength); + remainingLength -= read; + return read; + } + + public override int Read(Span buffer) + { + if (remainingLength == 0) + return 0; + + int readLength = Math.Min(buffer.Length, remainingLength); + int read = baseStream.Read(buffer[..readLength]); + remainingLength -= read; + return read; + } + + public override int ReadByte() + { + if (remainingLength == 0) + return -1; + + int value = baseStream.ReadByte(); + if (value == -1) + return -1; + + remainingLength--; + return value; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (remainingLength == 0) + return 0; + + int readLength = Math.Min(buffer.Length, remainingLength); + int read = await baseStream.ReadAsync(buffer[..readLength], cancellationToken); + remainingLength -= read; + return read; + } + + public new async ValueTask ReadExactlyAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (buffer.Length > remainingLength) + throw new OverflowException("Reached the end of the packet."); + + await baseStream.ReadExactlyAsync(buffer, cancellationToken); + remainingLength -= buffer.Length; + } + + public new void ReadExactly(Span buffer) + { + if (buffer.Length > remainingLength) + throw new OverflowException("Reached the end of the packet."); + + baseStream.ReadExactly(buffer); + remainingLength -= buffer.Length; + } + + public byte[] ReadRemaining() + { + if (remainingLength == 0) + return []; + + byte[] buffer = GC.AllocateUninitializedArray(remainingLength); + ReadExactly(buffer); + return buffer; + } + + public async Task ReadRemainingAsync(CancellationToken cancellationToken = default) + { + if (remainingLength == 0) + return []; + + byte[] buffer = GC.AllocateUninitializedArray(remainingLength); + await ReadExactlyAsync(buffer, cancellationToken); + return buffer; + } + + public void DrainRemaining() + { + if (remainingLength == 0) + return; + + byte[] buffer = GC.AllocateUninitializedArray(Math.Min(DrainBufferSize, remainingLength)); + while (remainingLength > 0) + { + int read = baseStream.Read(buffer, 0, Math.Min(buffer.Length, remainingLength)); + if (read <= 0) + throw new EndOfStreamException("Connection closed while draining packet data."); + + remainingLength -= read; + } + } + + public async ValueTask DrainRemainingAsync(CancellationToken cancellationToken = default) + { + if (remainingLength == 0) + return; + + byte[] buffer = GC.AllocateUninitializedArray(Math.Min(DrainBufferSize, remainingLength)); + while (remainingLength > 0) + { + int read = await baseStream.ReadAsync(buffer.AsMemory(0, Math.Min(buffer.Length, remainingLength)), cancellationToken); + if (read <= 0) + throw new EndOfStreamException("Connection closed while draining packet data."); + + remainingLength -= read; + } + } + + public override void Flush() + { + throw new NotSupportedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + + protected override void Dispose(bool disposing) + { + if (disposing && remainingLength > 0) + { + try + { + DrainRemaining(); + } + catch (IOException) { } + catch (ObjectDisposedException) { } + } + + base.Dispose(disposing); + } + + public override async ValueTask DisposeAsync() + { + if (remainingLength > 0) + { + try + { + await DrainRemainingAsync(); + } + catch (IOException) { } + catch (ObjectDisposedException) { } + } + + await base.DisposeAsync(); + } +} diff --git a/MinecraftClient/Protocol/ProtocolHandler.cs b/MinecraftClient/Protocol/ProtocolHandler.cs index a8568f3863..1a9c5b2599 100644 --- a/MinecraftClient/Protocol/ProtocolHandler.cs +++ b/MinecraftClient/Protocol/ProtocolHandler.cs @@ -7,6 +7,8 @@ using System.Net.Sockets; using System.Text; using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; using DnsClient; using MinecraftClient.Protocol.Handlers; using MinecraftClient.Protocol.Handlers.Forge; @@ -91,11 +93,30 @@ public static bool MinecraftServiceLookup(ref string domain, ref ushort port) /// TRUE if ping was successful public static bool GetServerInfo(string serverIP, ushort serverPort, ref int protocolversion, ref ForgeInfo? forgeInfo) + { + (bool success, int resolvedProtocolVersion, ForgeInfo? resolvedForgeInfo) = + GetServerInfoAsync(serverIP, serverPort, protocolversion).GetAwaiter().GetResult(); + + if (!success) + return false; + + if (protocolversion != 0 && protocolversion != resolvedProtocolVersion) + ConsoleIO.WriteLineFormatted("§8" + Translations.error_version_different, acceptnewlines: true); + if (protocolversion == 0 && resolvedProtocolVersion <= 1) + ConsoleIO.WriteLineFormatted("§8" + Translations.error_no_version_report, acceptnewlines: true); + if (protocolversion == 0) + protocolversion = resolvedProtocolVersion; + + forgeInfo = resolvedForgeInfo; + return true; + } + + public static async Task<(bool Success, int ProtocolVersion, ForgeInfo? ForgeInfo)> GetServerInfoAsync(string serverIP, ushort serverPort, int protocolversion) { bool success = false; int protocolversionTmp = 0; ForgeInfo? forgeInfoTmp = null; - if (AutoTimeout.Perform(() => + if (await AutoTimeout.PerformAsync(() => { try { @@ -118,19 +139,15 @@ public static bool GetServerInfo(string serverIP, ushort serverPort, ref int pro ? 10 : 30))) { - if (protocolversion != 0 && protocolversion != protocolversionTmp) - ConsoleIO.WriteLineFormatted("§8" + Translations.error_version_different, acceptnewlines: true); - if (protocolversion == 0 && protocolversionTmp <= 1) - ConsoleIO.WriteLineFormatted("§8" + Translations.error_no_version_report, acceptnewlines: true); if (protocolversion == 0) protocolversion = protocolversionTmp; - forgeInfo = forgeInfoTmp; - return success; + + return (success, protocolversion, forgeInfoTmp); } else { ConsoleIO.WriteLineFormatted("§8" + Translations.error_connection_timeout, acceptnewlines: true); - return false; + return (false, protocolversion, forgeInfoTmp); } } @@ -590,24 +607,23 @@ public enum AccountType /// Returns the status of the login (Success, Failure, etc.) public static LoginResult GetLogin(string user, string pass, LoginType type, out SessionToken session) { - if (type == LoginType.microsoft) - { - if (Config.Main.General.Method == LoginMethod.mcc) - return MicrosoftMCCLogin(user, pass, out session); - else - return MicrosoftBrowserLogin(out session, user); - } - else if (type == LoginType.mojang) - { - return MojangLogin(user, pass, out session); - } - else if (type == LoginType.yggdrasil) + var login = GetLoginAsync(user, pass, type).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + public static Task<(LoginResult Result, SessionToken Session)> GetLoginAsync(string user, string pass, LoginType type, CancellationToken cancellationToken = default) + { + return type switch { - return YggdrasiLogin(user, pass, out session); - } - else - throw new InvalidOperationException( - "Account type must be Mojang or Microsoft or valid authlib 3rd Servers!"); + LoginType.microsoft => Config.Main.General.Method == LoginMethod.mcc + ? MicrosoftMCCLoginAsync(user, pass, cancellationToken) + : MicrosoftBrowserLoginAsync(user, cancellationToken), + LoginType.mojang => MojangLoginAsync(user, pass, cancellationToken), + LoginType.yggdrasil => YggdrasiLoginAsync(user, pass, cancellationToken), + _ => throw new InvalidOperationException( + "Account type must be Mojang or Microsoft or valid authlib 3rd Servers!") + }; } /// @@ -619,20 +635,29 @@ public static LoginResult GetLogin(string user, string pass, LoginType type, out /// private static LoginResult MojangLogin(string user, string pass, out SessionToken session) { - session = new SessionToken() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; + var login = MojangLoginAsync(user, pass).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + private static async Task<(LoginResult Result, SessionToken Session)> MojangLoginAsync(string user, string pass, CancellationToken cancellationToken = default) + { + SessionToken session = new() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; try { - string result = ""; + string result; string json_request = "{\"agent\": { \"name\": \"Minecraft\", \"version\": 1 }, \"username\": \"" + JsonEncode(user) + "\", \"password\": \"" + JsonEncode(pass) + "\", \"clientToken\": \"" + JsonEncode(session.ClientID) + "\" }"; - int code = DoHTTPSPost("authserver.mojang.com", 443, "/authenticate", json_request, ref result); + var response = await DoHTTPSPostAsync("authserver.mojang.com", 443, "/authenticate", json_request, cancellationToken); + int code = response.StatusCode; + result = response.Result; if (code == 200) { if (result.Contains("availableProfiles\":[]}")) { - return LoginResult.NotPremium; + return (LoginResult.NotPremium, session); } else { @@ -645,27 +670,27 @@ private static LoginResult MojangLogin(string user, string pass, out SessionToke session.PlayerID = loginResponse["selectedProfile"]!["id"]!.GetStringValue(); session.PlayerName = loginResponse["selectedProfile"]!["name"]! .GetStringValue(); - return LoginResult.Success; + return (LoginResult.Success, session); } - else return LoginResult.InvalidResponse; + else return (LoginResult.InvalidResponse, session); } } else if (code == 403) { if (result.Contains("UserMigratedException")) { - return LoginResult.AccountMigrated; + return (LoginResult.AccountMigrated, session); } - else return LoginResult.WrongPassword; + else return (LoginResult.WrongPassword, session); } else if (code == 503) { - return LoginResult.ServiceUnavailable; + return (LoginResult.ServiceUnavailable, session); } else { ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.error_http_code, code)); - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } catch (System.Security.Authentication.AuthenticationException e) @@ -675,7 +700,7 @@ private static LoginResult MojangLogin(string user, string pass, out SessionToke ConsoleIO.WriteLineFormatted("§8" + e.ToString()); } - return LoginResult.SSLError; + return (LoginResult.SSLError, session); } catch (System.IO.IOException e) { @@ -686,9 +711,9 @@ private static LoginResult MojangLogin(string user, string pass, out SessionToke if (e.Message.Contains("authentication")) { - return LoginResult.SSLError; + return (LoginResult.SSLError, session); } - else return LoginResult.OtherError; + else return (LoginResult.OtherError, session); } catch (Exception e) { @@ -697,28 +722,41 @@ private static LoginResult MojangLogin(string user, string pass, out SessionToke ConsoleIO.WriteLineFormatted("§8" + e.ToString()); } - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } private static LoginResult YggdrasiLogin(string user, string pass, out SessionToken session) { - session = new SessionToken() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; + var login = YggdrasiLoginAsync(user, pass).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + private static async Task<(LoginResult Result, SessionToken Session)> YggdrasiLoginAsync(string user, string pass, CancellationToken cancellationToken = default) + { + SessionToken session = new() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; try { - string result = ""; + string result; string json_request = "{\"agent\": { \"name\": \"Minecraft\", \"version\": 1 }, \"username\": \"" + JsonEncode(user) + "\", \"password\": \"" + JsonEncode(pass) + "\", \"clientToken\": \"" + JsonEncode(session.ClientID) + "\" }"; - int code = DoHTTPSPost(Config.Main.General.AuthServer.Host, Config.Main.General.AuthServer.Port, - Config.Main.General.AuthServer.AuthlibInjectorAPIPath + "/authserver/authenticate", json_request, - Config.Main.General.AuthServer.UseHttps, ref result); + var response = await DoHTTPSPostAsync( + Config.Main.General.AuthServer.Host, + Config.Main.General.AuthServer.Port, + Config.Main.General.AuthServer.AuthlibInjectorAPIPath + "/authserver/authenticate", + json_request, + Config.Main.General.AuthServer.UseHttps, + cancellationToken); + int code = response.StatusCode; + result = response.Result; if (code == 200) { if (result.Contains("availableProfiles\":[]}")) { - return LoginResult.NotPremium; + return (LoginResult.NotPremium, session); } else { @@ -733,7 +771,7 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo .GetStringValue(); session.PlayerName = loginResponse["selectedProfile"]!["name"]! .GetStringValue(); - return LoginResult.Success; + return (LoginResult.Success, session); } else { @@ -769,33 +807,33 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo session.PlayerID = selectedProfile["id"]!.GetStringValue(); session.PlayerName = selectedProfile["name"]!.GetStringValue(); SessionToken currentsession = session; - return GetNewYggdrasilToken(currentsession, out session); + return await GetNewYggdrasilTokenAsync(currentsession, cancellationToken); } else { - return LoginResult.WrongSelection; + return (LoginResult.WrongSelection, session); } } } - else return LoginResult.InvalidResponse; + else return (LoginResult.InvalidResponse, session); } } else if (code == 403) { if (result.Contains("UserMigratedException")) { - return LoginResult.AccountMigrated; + return (LoginResult.AccountMigrated, session); } - else return LoginResult.WrongPassword; + else return (LoginResult.WrongPassword, session); } else if (code == 503) { - return LoginResult.ServiceUnavailable; + return (LoginResult.ServiceUnavailable, session); } else { ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.error_http_code, code)); - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } catch (System.Security.Authentication.AuthenticationException e) @@ -805,7 +843,7 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo ConsoleIO.WriteLineFormatted("§8" + e.ToString()); } - return LoginResult.SSLError; + return (LoginResult.SSLError, session); } catch (System.IO.IOException e) { @@ -816,9 +854,9 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo if (e.Message.Contains("authentication")) { - return LoginResult.SSLError; + return (LoginResult.SSLError, session); } - else return LoginResult.OtherError; + else return (LoginResult.OtherError, session); } catch (Exception e) { @@ -827,7 +865,7 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo ConsoleIO.WriteLineFormatted("§8" + e.ToString()); } - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } @@ -840,10 +878,17 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo /// /// private static LoginResult MicrosoftMCCLogin(string email, string password, out SessionToken session) + { + var login = MicrosoftMCCLoginAsync(email, password).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + private static async Task<(LoginResult Result, SessionToken Session)> MicrosoftMCCLoginAsync(string email, string password, CancellationToken cancellationToken = default) { try { - var deviceCode = Microsoft.RequestDeviceCode(); + var deviceCode = await Microsoft.RequestDeviceCodeAsync(cancellationToken); ConsoleIO.WriteLineFormatted(string.Format(Translations.mcc_device_code_prompt, deviceCode.VerificationUri, deviceCode.UserCode)); @@ -852,19 +897,19 @@ private static LoginResult MicrosoftMCCLogin(string email, string password, out ConsoleIO.WriteLineFormatted(Translations.mcc_device_code_waiting); - var msaResponse = Microsoft.PollDeviceCodeToken(deviceCode.DeviceCode, deviceCode.ExpiresIn, deviceCode.Interval); - return MicrosoftLogin(msaResponse, out session); + var msaResponse = await Microsoft.PollDeviceCodeTokenAsync(deviceCode.DeviceCode, deviceCode.ExpiresIn, deviceCode.Interval, cancellationToken); + return await MicrosoftLoginAsync(msaResponse, cancellationToken); } catch (Exception e) { - session = new SessionToken() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; + SessionToken session = new() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; ConsoleIO.WriteLineFormatted("§cMicrosoft authenticate failed: " + e.Message); if (Settings.Config.Logging.DebugMessages) { ConsoleIO.WriteLineFormatted("§c" + e.StackTrace); } - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } @@ -879,6 +924,13 @@ private static LoginResult MicrosoftMCCLogin(string email, string password, out /// /// public static LoginResult MicrosoftBrowserLogin(out SessionToken session, string loginHint = "") + { + var login = MicrosoftBrowserLoginAsync(loginHint).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + public static async Task<(LoginResult Result, SessionToken Session)> MicrosoftBrowserLoginAsync(string loginHint = "", CancellationToken cancellationToken = default) { if (string.IsNullOrEmpty(loginHint)) Microsoft.OpenBrowser(Microsoft.SignInUrl); @@ -891,40 +943,54 @@ public static LoginResult MicrosoftBrowserLogin(out SessionToken session, string string code = ConsoleIO.ReadLine(); ConsoleIO.WriteLine(string.Format(Translations.mcc_connecting, "Microsoft")); - var msaResponse = Microsoft.RequestAccessToken(code); - return MicrosoftLogin(msaResponse, out session); + var msaResponse = await Microsoft.RequestAccessTokenAsync(code); + return await MicrosoftLoginAsync(msaResponse, cancellationToken); } public static LoginResult MicrosoftLoginRefresh(string refreshToken, out SessionToken session) { - var msaResponse = Microsoft.RefreshAccessToken(refreshToken); - return MicrosoftLogin(msaResponse, out session); + var login = MicrosoftLoginRefreshAsync(refreshToken).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + public static async Task<(LoginResult Result, SessionToken Session)> MicrosoftLoginRefreshAsync(string refreshToken, CancellationToken cancellationToken = default) + { + var msaResponse = await Microsoft.RefreshAccessTokenAsync(refreshToken); + return await MicrosoftLoginAsync(msaResponse, cancellationToken); } private static LoginResult MicrosoftLogin(Microsoft.LoginResponse msaResponse, out SessionToken session) { - session = new SessionToken() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; + var login = MicrosoftLoginAsync(msaResponse).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + private static async Task<(LoginResult Result, SessionToken Session)> MicrosoftLoginAsync(Microsoft.LoginResponse msaResponse, CancellationToken cancellationToken = default) + { + SessionToken session = new() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; try { - var xblResponse = XboxLive.XblAuthenticate(msaResponse); - var xsts = XboxLive.XSTSAuthenticate(xblResponse); // Might throw even password correct + var xblResponse = await XboxLive.XblAuthenticateAsync(msaResponse, cancellationToken); + var xsts = await XboxLive.XSTSAuthenticateAsync(xblResponse, cancellationToken); // Might throw even password correct - string accessToken = MinecraftWithXbox.LoginWithXbox(xsts.UserHash, xsts.Token); - bool hasGame = MinecraftWithXbox.UserHasGame(accessToken); + string accessToken = await MinecraftWithXbox.LoginWithXboxAsync(xsts.UserHash, xsts.Token, cancellationToken); + bool hasGame = await MinecraftWithXbox.UserHasGameAsync(accessToken, cancellationToken); if (hasGame) { - var profile = MinecraftWithXbox.GetUserProfile(accessToken); + var profile = await MinecraftWithXbox.GetUserProfileAsync(accessToken, cancellationToken); session.PlayerName = profile.UserName; session.PlayerID = profile.UUID; session.ID = accessToken; session.RefreshToken = msaResponse.RefreshToken; InternalConfig.Account.Login = msaResponse.Email; - return LoginResult.Success; + return (LoginResult.Success, session); } else { - return LoginResult.NotPremium; + return (LoginResult.NotPremium, session); } } catch (Exception e) @@ -935,7 +1001,7 @@ private static LoginResult MicrosoftLogin(Microsoft.LoginResponse msaResponse, o ConsoleIO.WriteLineFormatted("§c" + e.StackTrace); } - return LoginResult.WrongPassword; // Might not always be wrong password + return (LoginResult.WrongPassword, session); // Might not always be wrong password } } @@ -1075,6 +1141,52 @@ public static LoginResult GetNewYggdrasilToken(SessionToken currentsession, out } } + public static async Task<(LoginResult Result, SessionToken Session)> GetNewYggdrasilTokenAsync(SessionToken currentsession, CancellationToken cancellationToken = default) + { + SessionToken session = new(); + try + { + string json_request = "{ \"accessToken\": \"" + JsonEncode(currentsession.ID) + + "\", \"clientToken\": \"" + JsonEncode(currentsession.ClientID) + + "\", \"selectedProfile\": { \"id\": \"" + JsonEncode(currentsession.PlayerID) + + "\", \"name\": \"" + JsonEncode(currentsession.PlayerName) + "\" } }"; + var response = await DoHTTPSPostAsync( + Config.Main.General.AuthServer.Host, + Config.Main.General.AuthServer.Port, + Config.Main.General.AuthServer.AuthlibInjectorAPIPath + "/authserver/refresh", + json_request, + Config.Main.General.AuthServer.UseHttps, + cancellationToken); + string result = response.Result; + int code = response.StatusCode; + if (code == 200) + { + var loginResponse = Json.ParseJson(result); + if (loginResponse?["accessToken"] is not null + && loginResponse["selectedProfile"]?["id"] is not null + && loginResponse["selectedProfile"]?["name"] is not null) + { + session.ID = loginResponse["accessToken"]!.GetStringValue(); + session.PlayerID = loginResponse["selectedProfile"]!["id"]!.GetStringValue(); + session.PlayerName = loginResponse["selectedProfile"]!["name"]!.GetStringValue(); + return (LoginResult.Success, session); + } + + return (LoginResult.InvalidResponse, session); + } + + if (code == 403 && result.Contains("InvalidToken")) + return (LoginResult.InvalidToken, session); + + ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.error_auth, code)); + return (LoginResult.OtherError, session); + } + catch + { + return (LoginResult.OtherError, session); + } + } + /// /// Check session using Mojang's Yggdrasil authentication scheme. Allows to join an online-mode server /// @@ -1108,6 +1220,43 @@ public static bool SessionCheck(string uuid, string accesstoken, string serverha } } + public static async Task SessionCheckAsync(string uuid, string accesstoken, string serverhash, LoginType type) + { + try + { + string jsonRequest = "{\"accessToken\":\"" + accesstoken + "\",\"selectedProfile\":\"" + uuid + + "\",\"serverId\":\"" + serverhash + "\"}"; + string host = type == LoginType.yggdrasil + ? Config.Main.General.AuthServer.Host + : "sessionserver.mojang.com"; + int port = type == LoginType.yggdrasil ? Config.Main.General.AuthServer.Port : 443; + string endpoint = type == LoginType.yggdrasil + ? Config.Main.General.AuthServer.AuthlibInjectorAPIPath + "/sessionserver/session/minecraft/join" + : "/session/minecraft/join"; + + bool useHttps = type == LoginType.yggdrasil ? Config.Main.General.AuthServer.UseHttps : true; + var response = await DoHTTPSRequestAsync( + HttpMethod.Post, + host, + port, + endpoint, + new Dictionary + { + { "Accept", "application/json" }, + { "Content-Type", "application/json" } + }, + jsonRequest, + useHttps, + CancellationToken.None); + + return response.StatusCode >= 200 && response.StatusCode < 300; + } + catch + { + return false; + } + } + /// /// Retrieve available Realms worlds of a player and display them /// @@ -1115,15 +1264,18 @@ public static bool SessionCheck(string uuid, string accesstoken, string serverha /// Player UUID /// Access token /// List of ID of available Realms worlds - public static List RealmsListWorlds(string username, string uuid, string accesstoken) + public static List RealmsListWorlds(string username, string uuid, string accesstoken) => + RealmsListWorldsAsync(username, uuid, accesstoken).GetAwaiter().GetResult(); + + public static async Task> RealmsListWorldsAsync(string username, string uuid, string accesstoken, CancellationToken cancellationToken = default) { List realmsWorldsResult = new(); // Store world ID try { - string result = ""; + string result; string cookies = String.Format("sid=token:{0}:{1};user={2};version={3}", accesstoken, uuid, username, Program.MCHighestVersion); - DoHTTPSGet("pc.realms.minecraft.net", 443, "/worlds", cookies, ref result); + (_, result) = await DoHTTPSGetAsync("pc.realms.minecraft.net", 443, "/worlds", cookies, cancellationToken); var realmsWorlds = Json.ParseJson(result); if (realmsWorlds?["servers"] is System.Text.Json.Nodes.JsonArray serversArray && serversArray.Count > 0) @@ -1179,15 +1331,21 @@ public static List RealmsListWorlds(string username, string uuid, string /// Access token /// Server address (host:port) or empty string if failure public static string GetRealmsWorldServerAddress(string worldId, string username, string uuid, - string accesstoken) + string accesstoken) => + GetRealmsWorldServerAddressAsync(worldId, username, uuid, accesstoken).GetAwaiter().GetResult(); + + public static async Task GetRealmsWorldServerAddressAsync(string worldId, string username, string uuid, + string accesstoken, CancellationToken cancellationToken = default) { try { - string result = ""; + string result; string cookies = String.Format("sid=token:{0}:{1};user={2};version={3}", accesstoken, uuid, username, Program.MCHighestVersion); - int statusCode = DoHTTPSGet("pc.realms.minecraft.net", 443, "/worlds/v1/" + worldId + "/join/pc", - cookies, ref result); + var response = await DoHTTPSGetAsync("pc.realms.minecraft.net", 443, "/worlds/v1/" + worldId + "/join/pc", + cookies, cancellationToken); + int statusCode = response.StatusCode; + result = response.Result; if (statusCode == 200) { var serverAddress = Json.ParseJson(result); @@ -1227,6 +1385,13 @@ public static string GetRealmsWorldServerAddress(string worldId, string username /// Request result /// HTTP Status code private static int DoHTTPSGet(string host, int port, string path, string cookies, ref string result) + { + var response = DoHTTPSGetAsync(host, port, path, cookies).GetAwaiter().GetResult(); + result = response.Result; + return response.StatusCode; + } + + private static Task<(int StatusCode, string Result)> DoHTTPSGetAsync(string host, int port, string path, string cookies, CancellationToken cancellationToken = default) { Dictionary headers = new() { @@ -1235,7 +1400,7 @@ private static int DoHTTPSGet(string host, int port, string path, string cookies { "Pragma", "no-cache" }, { "User-Agent", "Java/1.6.0_27" } }; - return DoHTTPSRequest(HttpMethod.Get, host, port, path, headers, null, useHttps: true, ref result); + return DoHTTPSRequestAsync(HttpMethod.Get, host, port, path, headers, null, useHttps: true, cancellationToken); } /// @@ -1261,13 +1426,23 @@ private static int DoHTTPSPost(string host, int port, string path, string body, /// Request result /// HTTP Status code private static int DoHTTPSPost(string host, int port, string path, string body, bool useHttps, ref string result) + { + var response = DoHTTPSPostAsync(host, port, path, body, useHttps).GetAwaiter().GetResult(); + result = response.Result; + return response.StatusCode; + } + + private static Task<(int StatusCode, string Result)> DoHTTPSPostAsync(string host, int port, string path, string body, CancellationToken cancellationToken = default) => + DoHTTPSPostAsync(host, port, path, body, useHttps: true, cancellationToken); + + private static Task<(int StatusCode, string Result)> DoHTTPSPostAsync(string host, int port, string path, string body, bool useHttps, CancellationToken cancellationToken = default) { Dictionary headers = new() { { "User-Agent", "MCC/" + Program.Version }, { "Content-Type", "application/json" } }; - return DoHTTPSRequest(HttpMethod.Post, host, port, path, headers, body, useHttps, ref result); + return DoHTTPSRequestAsync(HttpMethod.Post, host, port, path, headers, body, useHttps, cancellationToken); } /// @@ -1284,69 +1459,62 @@ private static int DoHTTPSPost(string host, int port, string path, string body, /// HTTP Status code private static int DoHTTPSRequest(HttpMethod method, string host, int port, string path, Dictionary headers, string? body, bool useHttps, ref string result) { - string? postResult = null; - int statusCode = 520; - Exception? exception = null; - AutoTimeout.Perform(() => + var response = DoHTTPSRequestAsync(method, host, port, path, headers, body, useHttps, CancellationToken.None) + .GetAwaiter() + .GetResult(); + result = response.Result; + return response.StatusCode; + } + + private static async Task<(int StatusCode, string Result)> DoHTTPSRequestAsync(HttpMethod method, string host, int port, string path, Dictionary headers, string? body, bool useHttps, CancellationToken cancellationToken) + { + if (Settings.Config.Logging.DebugMessages) + ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.debug_request, host)); + + using SocketsHttpHandler handler = new(); + handler.ConnectCallback = async (ctx, ct) => { - try - { - if (Settings.Config.Logging.DebugMessages) - ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.debug_request, host)); + TcpClient client = ProxyHandler.NewTcpClient(host, port, true); + return client.GetStream(); + }; - using SocketsHttpHandler handler = new SocketsHttpHandler(); - handler.ConnectCallback = async (ctx, ct) => - { - TcpClient client = ProxyHandler.NewTcpClient(host, port, true); - return client.GetStream(); - }; + using HttpClient client = new(handler); - using HttpClient client = new HttpClient(handler); + string scheme = useHttps ? "https" : "http"; + using HttpRequestMessage request = new(method, scheme + "://" + host + ":" + port + path); - string scheme = useHttps ? "https" : "http"; - var request = new HttpRequestMessage(method, scheme + "://" + host + ":" + port + path); + string contentType = "text/plain"; + foreach (var header in headers) + { + request.Headers.TryAddWithoutValidation(header.Key, header.Value); + if (header.Key.Equals("Content-Type", StringComparison.OrdinalIgnoreCase)) + contentType = header.Value; + } - var contentType = "text/plain"; - foreach (var header in headers) - { - request.Headers.TryAddWithoutValidation(header.Key, header.Value); - if (header.Key.Equals("Content-Type", StringComparison.OrdinalIgnoreCase)) - contentType = header.Value; - } + if (body is not null) + request.Content = new StringContent(body, Encoding.UTF8, contentType); - if (body is not null) - request.Content = new StringContent(body, Encoding.UTF8, contentType); + if (Settings.Config.Logging.DebugMessages) + ConsoleIO.WriteLineFormatted("§8> " + request); - if (Settings.Config.Logging.DebugMessages) - ConsoleIO.WriteLineFormatted("§8> " + request); + using CancellationTokenSource timeoutCancellationTokenSource = + CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + timeoutCancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(30)); - HttpResponseMessage response = client.SendAsync(request).GetAwaiter().GetResult(); - statusCode = (int)response.StatusCode; + using HttpResponseMessage response = await client.SendAsync(request, timeoutCancellationTokenSource.Token); + int statusCode = (int)response.StatusCode; + string responseBody = statusCode == 204 + ? "No Content" + : await response.Content.ReadAsStringAsync(timeoutCancellationTokenSource.Token); - postResult = statusCode == 204 - ? "No Content" - : response.Content.ReadAsStringAsync().GetAwaiter().GetResult(); + if (Settings.Config.Logging.DebugMessages) + { + ConsoleIO.WriteLine(""); + foreach (string line in responseBody.Split('\n')) + ConsoleIO.WriteLineFormatted("§8< " + line); + } - if (Settings.Config.Logging.DebugMessages) - { - ConsoleIO.WriteLine(""); - foreach (string line in postResult.Split('\n')) - ConsoleIO.WriteLineFormatted("§8< " + line); - } - } - catch (Exception e) - { - if (e is not System.Threading.ThreadAbortException) - { - exception = e; - } - } - }, TimeSpan.FromSeconds(30)); - if (postResult is not null) - result = postResult; - if (exception is not null) - throw exception; - return statusCode; + return (statusCode, responseBody); } /// @@ -1389,4 +1557,4 @@ public static DateTime UnixTimeStampToDateTime(long unixTimeStamp) return dateTime; } } -} \ No newline at end of file +} diff --git a/MinecraftClient/Protocol/ProxiedWebRequest.cs b/MinecraftClient/Protocol/ProxiedWebRequest.cs index 220ac3191a..7911bc329e 100644 --- a/MinecraftClient/Protocol/ProxiedWebRequest.cs +++ b/MinecraftClient/Protocol/ProxiedWebRequest.cs @@ -3,6 +3,8 @@ using System.Net; using System.Net.Http; using System.Text; +using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Proxy; namespace MinecraftClient.Protocol @@ -72,6 +74,12 @@ private void SetupBasicHeaders() /// public Response Get() => Send(HttpMethod.Get); + /// + /// Perform GET request asynchronously. Proxy is handled automatically. + /// + public Task GetAsync(CancellationToken cancellationToken = default) => + SendAsync(HttpMethod.Get, cancellationToken: cancellationToken); + /// /// Perform POST request. Proxy is handled automatically. /// @@ -79,6 +87,14 @@ private void SetupBasicHeaders() /// Request body public Response Post(string contentType, string body) => Send(HttpMethod.Post, contentType, body); + /// + /// Perform POST request asynchronously. Proxy is handled automatically. + /// + /// The content type of request body + /// Request body + public Task PostAsync(string contentType, string body, CancellationToken cancellationToken = default) => + SendAsync(HttpMethod.Post, contentType, body, cancellationToken); + /// /// Send an HTTP request. Proxy is configured automatically from Settings. /// @@ -144,6 +160,66 @@ private Response Send(HttpMethod method, string? contentType = null, string? bod } } + /// + /// Send an HTTP request asynchronously. Proxy is configured automatically from Settings. + /// + private async Task SendAsync(HttpMethod method, string? contentType = null, string? body = null, CancellationToken cancellationToken = default) + { + using var handler = CreateHandler(); + using var client = new HttpClient(handler); + + using var request = new HttpRequestMessage(method, _uri); + + foreach (string key in Headers) + { + if (key.Equals("Content-Type", StringComparison.OrdinalIgnoreCase) || + key.Equals("Content-Length", StringComparison.OrdinalIgnoreCase) || + key.Equals("Host", StringComparison.OrdinalIgnoreCase)) + continue; + + request.Headers.TryAddWithoutValidation(key, Headers[key]); + } + + if (body is not null) + request.Content = new StringContent(body, Encoding.UTF8, contentType ?? "text/plain"); + + if (Debug) + { + ConsoleIO.WriteLine($"< {method} {_uri}"); + foreach (string key in Headers) + ConsoleIO.WriteLine($"< {key}: {Headers[key]}"); + } + + try + { + using var httpResponse = await client.SendAsync(request, cancellationToken); + string responseBody = await httpResponse.Content.ReadAsStringAsync(cancellationToken); + + var responseHeaders = new NameValueCollection(); + foreach (var header in httpResponse.Headers) + foreach (var val in header.Value) + responseHeaders.Add(header.Key.ToLowerInvariant(), val); + foreach (var header in httpResponse.Content.Headers) + foreach (var val in header.Value) + responseHeaders.Add(header.Key.ToLowerInvariant(), val); + + var cookies = new NameValueCollection(); + foreach (Cookie cookie in handler.CookieContainer.GetCookies(_uri)) + { + if (!cookie.Expired) + cookies.Add(cookie.Name, cookie.Value); + } + + return new Response((int)httpResponse.StatusCode, responseBody, responseHeaders, cookies); + } + catch (HttpRequestException ex) + { + if (Debug) + ConsoleIO.WriteLine("HTTP error: " + ex.Message); + return Response.Empty(); + } + } + /// /// Create a SocketsHttpHandler with proxy support from ProxyHandler settings. /// @@ -231,4 +307,4 @@ public override string ToString() } } } -} \ No newline at end of file +} diff --git a/MinecraftClient/Protocol/Session/SessionToken.cs b/MinecraftClient/Protocol/Session/SessionToken.cs index 1364012bc8..244c82a403 100644 --- a/MinecraftClient/Protocol/Session/SessionToken.cs +++ b/MinecraftClient/Protocol/Session/SessionToken.cs @@ -54,6 +54,16 @@ public bool SessionPreCheck(LoginType type) return false; } + public async Task SessionPreCheckAsync(LoginType type) + { + if (ID == string.Empty || PlayerID == String.Empty || ServerPublicKey is null) + return false; + + Crypto.CryptoHandler.ClientAESPrivateKey ??= Crypto.CryptoHandler.GenerateAESPrivateKey(); + string serverHash = Crypto.CryptoHandler.GetServerHash(ServerIDhash, ServerPublicKey, Crypto.CryptoHandler.ClientAESPrivateKey); + return await ProtocolHandler.SessionCheckAsync(PlayerID, ID, serverHash, type); + } + public override string ToString() { return String.Join(",", ID, PlayerName, PlayerID, ClientID, RefreshToken, ServerIDhash, diff --git a/MinecraftClient/Scripting/MccGameApi.cs b/MinecraftClient/Scripting/MccGameApi.cs index 391ece5024..4bd18f1bf5 100644 --- a/MinecraftClient/Scripting/MccGameApi.cs +++ b/MinecraftClient/Scripting/MccGameApi.cs @@ -704,9 +704,76 @@ public MccGameResult MoveToPlayer(string playerName, bool /// /// Run on a worker thread so ChatBot callbacks can poll the result without blocking MCC updates. /// - public Task> MoveToPlayerAsync(string playerName, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) + public async Task> MoveToPlayerAsync(string playerName, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) { - return Task.Run(() => MoveToPlayer(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); + if (string.IsNullOrWhiteSpace(playerName)) + return MccGameResult.Fail("invalid_args"); + + if (!AreValidPathOffsets(maxOffset, minOffset) || timeoutMs < 0) + return MccGameResult.Fail("invalid_args"); + + McClient? client = clientProvider(); + if (client is null) + return NotConnected(); + + if (!client.GetTerrainEnabled() || !client.GetEntityHandlingEnabled()) + return MccGameResult.Fail("feature_disabled"); + + string nameFilter = playerName.Trim(); + NearbyPlayerSnapshot? target = client.InvokeOnMainThread(() => + { + return BuildTrackedPlayerSnapshots(client, includeSelf: false) + .Where(player => PlayerNameMatches(player, nameFilter)) + .OrderBy(player => player.Distance) + .FirstOrDefault(); + }); + + if (target is null) + return MccGameResult.Fail("invalid_state"); + + Location goal = new(target.X, target.Y, target.Z); + Location startLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + TimeSpan? timeout = timeoutMs > 0 ? TimeSpan.FromMilliseconds(timeoutMs) : null; + bool pathFound = client.InvokeOnMainThread(() => client.MoveTo(goal, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeout)); + + int verifyWaitMs = GetArrivalWaitMs(timeoutMs); + double tolerance = GetArrivalTolerance(maxOffset, minOffset); + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + bool arrived = false; + + if (pathFound) + { + (arrived, finalLocation) = await WaitForArrivalAsync(client, goal, verifyWaitMs, tolerance); + } + + MccMoveToPlayerResult resultData = new() + { + PathFound = pathFound, + Arrived = arrived, + Tolerance = tolerance, + VerifyWaitMs = verifyWaitMs, + Target = new MccMoveToPlayerTarget + { + PlayerName = target.Name, + EntityId = target.EntityId, + X = MccGameCommon.RoundCoordinate(target.X), + Y = MccGameCommon.RoundCoordinate(target.Y), + Z = MccGameCommon.RoundCoordinate(target.Z) + }, + StartLocation = MccGameCommon.ToCoordinate(startLocation), + FinalLocation = MccGameCommon.ToCoordinate(finalLocation), + FinalDistance = MccGameCommon.GetDistance(finalLocation, goal), + DistanceMoved = MccGameCommon.GetDistance(startLocation, finalLocation), + AllowUnsafe = allowUnsafe, + AllowDirectTeleport = allowDirectTeleport, + MaxOffset = maxOffset, + MinOffset = minOffset, + TimeoutMs = timeoutMs + }; + + return pathFound && arrived + ? MccGameResult.Ok(resultData) + : MccGameResult.Fail("action_incomplete", data: resultData); } /// @@ -1055,9 +1122,94 @@ public MccGameResult PickupItems(string itemType, double r /// /// Run on a worker thread so ChatBot callbacks can poll the result without blocking MCC updates. /// - public Task> PickupItemsAsync(string itemType, double radius = 16, int maxItems = 10, bool allowUnsafe = false, int timeoutMs = 0) + public async Task> PickupItemsAsync(string itemType, double radius = 16, int maxItems = 10, bool allowUnsafe = false, int timeoutMs = 0) { - return Task.Run(() => PickupItems(itemType, radius, maxItems, allowUnsafe, timeoutMs)); + if (string.IsNullOrWhiteSpace(itemType) || radius <= 0 || radius > 1024 || maxItems < 1 || timeoutMs < 0) + return MccGameResult.Fail("invalid_args"); + + if (!MccGameCommon.TryParseItemType(itemType.Trim(), out ItemType parsedItemType)) + return MccGameResult.Fail("invalid_args"); + + McClient? client = clientProvider(); + if (client is null) + return NotConnected(); + + if (!client.GetTerrainEnabled() || !client.GetEntityHandlingEnabled()) + return MccGameResult.Fail("feature_disabled"); + + int limit = Math.Clamp(maxItems, 1, 50); + NearbyItemSnapshot[] targets = client.InvokeOnMainThread(() => BuildNearbyItemSnapshots(client, parsedItemType, radius, limit)); + if (targets.Length == 0) + return MccGameResult.Fail("invalid_state"); + + bool inventoryEnabled = client.GetInventoryEnabled(); + int beforeCount = inventoryEnabled ? client.InvokeOnMainThread(() => GetInventoryItemCount(client, parsedItemType)) : 0; + int initialCount = beforeCount; + int verifyWaitMs = timeoutMs > 0 ? Math.Clamp(timeoutMs, MinArrivalWaitMs, MaxArrivalWaitMs) : 2500; + List attempts = new(targets.Length); + int successfulPickups = 0; + + foreach (NearbyItemSnapshot target in targets) + { + Location targetLocation = new(target.X, target.Y, target.Z); + Location startLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + TimeSpan? moveTimeout = timeoutMs > 0 ? TimeSpan.FromMilliseconds(timeoutMs) : null; + bool pathFound = client.InvokeOnMainThread(() => client.MoveTo(targetLocation, allowUnsafe, false, 0, 0, moveTimeout)); + + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + bool arrived = false; + if (pathFound) + { + (arrived, finalLocation) = await WaitForArrivalAsync(client, targetLocation, verifyWaitMs, 2.0); + } + + bool entityGone = await WaitForEntityRemovalAsync(client, target.EntityId, verifyWaitMs); + int afterCount = inventoryEnabled ? client.InvokeOnMainThread(() => GetInventoryItemCount(client, parsedItemType)) : beforeCount; + int inventoryDelta = inventoryEnabled ? Math.Max(0, afterCount - beforeCount) : 0; + bool pickedUp = entityGone || inventoryDelta > 0; + if (pickedUp) + successfulPickups++; + + attempts.Add(new MccPickupAttempt + { + EntityId = target.EntityId, + ItemType = target.ItemType.ToString(), + TypeLabel = target.TypeLabel, + ExpectedCount = target.Count, + Target = MccGameCommon.ToCoordinate(target.X, target.Y, target.Z), + PathFound = pathFound, + Arrived = arrived, + EntityGone = entityGone, + InventoryDelta = inventoryDelta, + StartLocation = MccGameCommon.ToCoordinate(startLocation), + FinalLocation = MccGameCommon.ToCoordinate(finalLocation), + FinalDistance = MccGameCommon.GetDistance(finalLocation, targetLocation) + }); + + beforeCount = afterCount; + } + + int remainingNearby = client.InvokeOnMainThread(() => BuildNearbyItemSnapshots(client, parsedItemType, radius, 1000).Length); + int collectedCount = inventoryEnabled ? Math.Max(0, beforeCount - initialCount) : successfulPickups; + MccPickupItemsResult resultData = new() + { + ItemType = parsedItemType.ToString(), + Radius = radius, + MaxItems = limit, + AllowUnsafe = allowUnsafe, + TimeoutMs = verifyWaitMs, + Attempted = attempts.Count, + SuccessfulPickups = successfulPickups, + CollectedCount = collectedCount, + InitialInventoryCount = inventoryEnabled ? initialCount : null, + FinalInventoryCount = inventoryEnabled ? beforeCount : null, + RemainingNearby = remainingNearby, + Attempts = attempts.ToArray() + }; + + return successfulPickups > 0 + ? MccGameResult.Ok(resultData) + : MccGameResult.Fail("action_incomplete", data: resultData); } private static MccGameResult NotConnected() @@ -1132,6 +1284,24 @@ private static bool WaitForArrival(McClient client, Location goal, int waitMs, d } } + private static async Task<(bool Arrived, Location FinalLocation)> WaitForArrivalAsync(McClient client, Location goal, int waitMs, double tolerance) + { + DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + + while (true) + { + finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + if (MccGameCommon.GetDistance(finalLocation, goal) <= tolerance) + return (true, finalLocation); + + if (DateTime.UtcNow >= deadline) + return (false, finalLocation); + + await Task.Delay(ArrivalPollIntervalMs); + } + } + private static bool WaitForEntityRemoval(McClient client, int entityId, int waitMs) { DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); @@ -1148,6 +1318,22 @@ private static bool WaitForEntityRemoval(McClient client, int entityId, int wait } } + private static async Task WaitForEntityRemovalAsync(McClient client, int entityId, int waitMs) + { + DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); + while (true) + { + bool exists = client.InvokeOnMainThread(() => client.GetEntities().ContainsKey(entityId)); + if (!exists) + return true; + + if (DateTime.UtcNow >= deadline) + return false; + + await Task.Delay(ArrivalPollIntervalMs); + } + } + private static int GetInventoryItemCount(McClient client, ItemType itemType) { Container? inventory = client.GetInventory(0); diff --git a/MinecraftClient/TaskWithResult.cs b/MinecraftClient/TaskWithResult.cs index 53aec3aa4f..23b0dbc1e0 100644 --- a/MinecraftClient/TaskWithResult.cs +++ b/MinecraftClient/TaskWithResult.cs @@ -1,20 +1,24 @@ using System; using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { + internal interface IMainThreadTask + { + void ExecuteSynchronously(); + void Cancel(); + } + /// /// Holds an asynchronous task with return value /// /// Type of the return value - public class TaskWithResult + public sealed class TaskWithResult : IMainThreadTask { - private readonly AutoResetEvent resultEvent = new(false); private readonly Func task; - private T? result = default; - private Exception? exception = null; - private bool taskRun = false; - private readonly Lock taskRunLock = new(); + private readonly TaskCompletionSource completionSource = new(TaskCreationOptions.RunContinuationsAsynchronously); + private int taskState; /// /// Create a new asynchronous task with return value @@ -28,13 +32,7 @@ public TaskWithResult(Func task) /// /// Check whether the task has finished running /// - public bool HasRun - { - get - { - return taskRun; - } - } + public bool HasRun => completionSource.Task.IsCompleted; /// /// Get the task result (return value of the inner delegate) @@ -44,10 +42,10 @@ public T Result { get { - if (taskRun) - return result!; - else + if (!completionSource.Task.IsCompleted) throw new InvalidOperationException("Attempting to retrieve the result of an unfinished task"); + + return completionSource.Task.GetAwaiter().GetResult(); } } @@ -58,40 +56,39 @@ public Exception? Exception { get { - return exception; + return completionSource.Task.Exception?.InnerException; } } + public Task AsTask() + { + return completionSource.Task; + } + /// /// Execute the task in the current thread and set the property or to the returned value /// public void ExecuteSynchronously() { - // Make sur the task will not run twice - lock (taskRunLock) - { - if (taskRun) - { - throw new InvalidOperationException("Attempting to run a task twice"); - } - } + if (Interlocked.CompareExchange(ref taskState, 1, 0) != 0) + throw new InvalidOperationException("Attempting to run a task twice"); - // Run the task try { - result = task(); + completionSource.TrySetResult(task()); } catch (Exception e) { - exception = e; + completionSource.TrySetException(e); } + } - // Mark task as complete and release wait event - lock (taskRunLock) - { - taskRun = true; - } - resultEvent.Set(); + public void Cancel() + { + if (Interlocked.CompareExchange(ref taskState, 1, 0) != 0) + return; + + completionSource.TrySetException(new OperationCanceledException("Main-thread task was canceled before execution.")); } /// @@ -101,22 +98,7 @@ public void ExecuteSynchronously() /// Any exception thrown by the task public T WaitGetResult() { - // Wait only if the result is not available yet - bool mustWait = false; - lock (taskRunLock) - { - mustWait = !taskRun; - } - if (mustWait) - { - resultEvent.WaitOne(); - } - - // Receive exception from task - if (exception is not null) - throw exception; - - return result!; + return completionSource.Task.GetAwaiter().GetResult(); } } } diff --git a/MinecraftClient/Tui/TuiConsoleBackend.cs b/MinecraftClient/Tui/TuiConsoleBackend.cs index bee5975116..7a12df8927 100644 --- a/MinecraftClient/Tui/TuiConsoleBackend.cs +++ b/MinecraftClient/Tui/TuiConsoleBackend.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Runtime.InteropServices; using System.Threading; +using System.Threading.Tasks; using Avalonia; using Avalonia.Threading; using Consolonia; @@ -48,12 +49,11 @@ internal void RunTuiMainLoop(string[] args, Program.StartupState startupState) Dispatcher.UIThread.Post(() => view.HandleCtrlC()); }; - new Thread(() => + _ = Task.Run(() => { _viewReady.Wait(); ContinueMccStartup(args); - }) - { Name = "MCC-Main", IsBackground = true }.Start(); + }); AppBuilder builder = AppBuilder.Configure() .UseConsolonia() @@ -206,32 +206,49 @@ internal void DismissOverlay() } public string RequestImmediateInput() + { + return RequestImmediateInputAsync(CancellationToken.None).GetAwaiter().GetResult(); + } + + public Task RequestImmediateInputAsync(CancellationToken cancellationToken) { if (_shutdownRequested) { - Thread.Sleep(Timeout.Infinite); - return string.Empty; + return Task.FromCanceled(cancellationToken.CanBeCanceled + ? cancellationToken + : new CancellationToken(canceled: true)); } - var mre = new ManualResetEventSlim(false); - string? result = null; + TaskCompletionSource completion = new(TaskCreationOptions.RunContinuationsAsynchronously); void Handler(object? sender, string e) { - result = e; - mre.Set(); + MessageReceived -= Handler; + completion.TrySetResult(e); } MessageReceived += Handler; - mre.Wait(); - MessageReceived -= Handler; - return result ?? string.Empty; + if (cancellationToken.CanBeCanceled) + { + cancellationToken.Register(() => + { + MessageReceived -= Handler; + completion.TrySetCanceled(cancellationToken); + }); + } + + return completion.Task; } public string? ReadPassword() { - return RequestImmediateInput(); + return ReadPasswordAsync(CancellationToken.None).GetAwaiter().GetResult(); + } + + public async Task ReadPasswordAsync(CancellationToken cancellationToken) + { + return await RequestImmediateInputAsync(cancellationToken); } public void ClearInputBuffer() @@ -267,11 +284,11 @@ public void Shutdown() Dispatcher.UIThread.Post(() => lifetime.Shutdown()); } - new Thread(() => + _ = Task.Run(async () => { - Thread.Sleep(1000); + await Task.Delay(1000); Environment.Exit(0); - }) { Name = "TUI-Exit-Guard", IsBackground = true }.Start(); + }); } private volatile bool _shutdownRequested;