From cf3424984a03587199e4779dd2340941783f3097 Mon Sep 17 00:00:00 2001 From: Tian L <60599517+tian-lt@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:17:15 +0800 Subject: [PATCH] Fix: heap corruption in the Graphing mode (#2152) * Run render pass asynchronously on UI thread * remove unused fields * async for SetVariable * resolve comments --- src/GraphControl/Control/Grapher.cpp | 117 +++++++++--------------- src/GraphControl/Control/Grapher.h | 3 +- src/GraphControl/DirectX/RenderMain.cpp | 92 ++++++------------- src/GraphControl/DirectX/RenderMain.h | 22 +---- 4 files changed, 71 insertions(+), 163 deletions(-) diff --git a/src/GraphControl/Control/Grapher.cpp b/src/GraphControl/Control/Grapher.cpp index 564188e4..81191a75 100644 --- a/src/GraphControl/Control/Grapher.cpp +++ b/src/GraphControl/Control/Grapher.cpp @@ -11,7 +11,6 @@ using namespace GraphControl; using namespace GraphControl::DX; using namespace Platform; using namespace Platform::Collections; -using namespace std; using namespace Concurrency; using namespace Windows::Devices::Input; using namespace Windows::Foundation; @@ -54,9 +53,9 @@ namespace // posX/posY are the pointer position elements and width,height are the dimensions of the graph surface. // The graphing engine interprets x,y position between the range [-1, 1]. // Translate the pointer position to the [-1, 1] bounds. - __inline pair PointerPositionToGraphPosition(double posX, double posY, double width, double height) + __inline std::pair PointerPositionToGraphPosition(double posX, double posY, double width, double height) { - return make_pair((2 * posX / width - 1), (1 - 2 * posY / height)); + return { (2 * posX / width - 1), (1 - 2 * posY / height) }; } } @@ -71,7 +70,7 @@ namespace GraphControl m_solver->ParsingOptions().SetFormatType(s_defaultFormatType); m_solver->FormatOptions().SetFormatType(s_defaultFormatType); - m_solver->FormatOptions().SetMathMLPrefix(wstring(L"mml")); + m_solver->FormatOptions().SetMathMLPrefix(L"mml"); DefaultStyleKey = StringReference(s_defaultStyleKey); @@ -95,21 +94,10 @@ namespace GraphControl { if (m_graph != nullptr && m_renderMain != nullptr) { - if (auto renderer = m_graph->GetRenderer()) + if (auto renderer = m_graph->GetRenderer(); static_cast(renderer) && SUCCEEDED(renderer->ScaleRange(centerX, centerY, scale))) { - m_renderMain->GetCriticalSection().lock(); - - if (SUCCEEDED(renderer->ScaleRange(centerX, centerY, scale))) - { - m_renderMain->GetCriticalSection().unlock(); - - m_renderMain->RunRenderPass(); - GraphViewChangedEvent(this, GraphViewChangedReason::Manipulation); - } - else - { - m_renderMain->GetCriticalSection().unlock(); - } + m_renderMain->RunRenderPass(); + GraphViewChangedEvent(this, GraphViewChangedReason::Manipulation); } } } @@ -251,7 +239,7 @@ namespace GraphControl if (auto analyzer = graph->GetAnalyzer()) { - vector equationVector; + std::vector equationVector; equationVector.push_back(equation); UpdateGraphOptions(graph->GetOptions(), equationVector); bool variableIsNotX; @@ -325,15 +313,15 @@ namespace GraphControl task Grapher::TryUpdateGraph(bool keepCurrentView) { - optional>> initResult = nullopt; + std::optional>> initResult; bool successful = false; m_errorCode = 0; m_errorType = 0; if (m_renderMain && m_graph != nullptr) { - unique_ptr graphExpression; - wstring request; + std::unique_ptr graphExpression; + std::wstring request; auto validEqs = GetGraphableEquations(); @@ -371,13 +359,12 @@ namespace GraphControl co_return false; } - unique_ptr expr; - wstring parsableEquation = s_getGraphOpeningTags; + std::wstring parsableEquation = s_getGraphOpeningTags; parsableEquation += equationRequest; parsableEquation += s_getGraphClosingTags; // Wire up the corresponding error to an error message in the UI at some point - if (!(expr = m_solver->ParseInput(parsableEquation, m_errorCode, m_errorType))) + if (auto expr = m_solver->ParseInput(parsableEquation, m_errorCode, m_errorType); !static_cast(expr)) { co_return false; } @@ -392,11 +379,11 @@ namespace GraphControl { initResult = TryInitializeGraph(keepCurrentView, graphExpression.get()); - if (initResult != nullopt) + if (initResult.has_value()) { - auto graphedEquations = initResult.value(); + auto& graphedEquations = *initResult; - for (int i = 0; i < validEqs.size(); i++) + for (size_t i = 0; i < validEqs.size(); ++i) { validEqs[i]->GraphedEquation = graphedEquations[i]; } @@ -407,8 +394,8 @@ namespace GraphControl m_renderMain->Graph = m_graph; // It is possible that the render fails, in that case fall through to explicit empty initialization - co_await m_renderMain->RunRenderPassAsync(false); - if (m_renderMain->IsRenderPassSuccesful()) + auto succ = co_await m_renderMain->RunRenderPassAsync(false); + if (succ) { UpdateVariables(); successful = true; @@ -417,7 +404,7 @@ namespace GraphControl { // If we failed to render then we have already lost the previous graph shouldKeepPreviousGraph = false; - initResult = nullopt; + initResult.reset(); m_solver->HRErrorToErrorInfo(m_renderMain->GetRenderError(), m_errorCode, m_errorType); } } @@ -427,15 +414,15 @@ namespace GraphControl } } - if (initResult == nullopt) + if (!initResult.has_value()) { // Do not re-initialize the graph to empty if there are still valid equations graphed if (!shouldKeepPreviousGraph) { initResult = TryInitializeGraph(false, nullptr); - if (initResult != nullopt) + if (initResult.has_value()) { - UpdateGraphOptions(m_graph->GetOptions(), vector()); + UpdateGraphOptions(m_graph->GetOptions(), {}); SetGraphArgs(m_graph); m_renderMain->Graph = m_graph; @@ -475,12 +462,10 @@ namespace GraphControl } } - void Grapher::SetGraphArgs(shared_ptr graph) + void Grapher::SetGraphArgs(std::shared_ptr graph) { if (graph != nullptr && m_renderMain != nullptr) { - critical_section::scoped_lock lock(m_renderMain->GetCriticalSection()); - for (auto variablePair : Variables) { graph->SetArgValue(variablePair->Key->Data(), variablePair->Value->Value); @@ -488,17 +473,17 @@ namespace GraphControl } } - shared_ptr Grapher::GetGraph(Equation ^ equation) + std::shared_ptr Grapher::GetGraph(Equation ^ equation) { - shared_ptr graph = m_solver->CreateGrapher(); + std::shared_ptr graph = m_solver->CreateGrapher(); - wstring request = s_getGraphOpeningTags; + std::wstring request = s_getGraphOpeningTags; request += equation->GetRequest()->Data(); request += s_getGraphClosingTags; - if (unique_ptr graphExpression = m_solver->ParseInput(request, m_errorCode, m_errorType)) + if (auto expr = m_solver->ParseInput(request, m_errorCode, m_errorType); static_cast(expr)) { - if (graph->TryInitialize(graphExpression.get())) + if (graph->TryInitialize(expr.get())) { return graph; } @@ -557,19 +542,12 @@ namespace GraphControl if (m_graph != nullptr && m_renderMain != nullptr) { - auto workItemHandler = ref new WorkItemHandler([this, variableName, newValue](IAsyncAction ^ action) { - m_renderMain->GetCriticalSection().lock(); - m_graph->SetArgValue(variableName->Data(), newValue); - m_renderMain->GetCriticalSection().unlock(); - - m_renderMain->RunRenderPass(); - }); - - ThreadPool::RunAsync(workItemHandler, WorkItemPriority::High, WorkItemOptions::None); + m_graph->SetArgValue(variableName->Data(), newValue); + [](RenderMain ^ renderMain) -> winrt::fire_and_forget { co_await renderMain->RunRenderPassAsync(); }(m_renderMain); } } - void Grapher::UpdateGraphOptions(IGraphingOptions& options, const vector& validEqs) + void Grapher::UpdateGraphOptions(IGraphingOptions& options, const std::vector& validEqs) { options.SetForceProportional(ForceProportionalAxes); @@ -580,7 +558,7 @@ namespace GraphControl if (!validEqs.empty()) { - vector graphColors; + std::vector graphColors; graphColors.reserve(validEqs.size()); for (Equation ^ eq : validEqs) { @@ -595,17 +573,17 @@ namespace GraphControl } eq->GraphedEquation->GetGraphEquationOptions()->SetLineStyle(static_cast<::Graphing::Renderer::LineStyle>(eq->EquationStyle)); - eq->GraphedEquation->GetGraphEquationOptions()->SetLineWidth(LineWidth); - eq->GraphedEquation->GetGraphEquationOptions()->SetSelectedEquationLineWidth(LineWidth + ((LineWidth <= 2) ? 1 : 2)); + eq->GraphedEquation->GetGraphEquationOptions()->SetLineWidth(static_cast(LineWidth)); + eq->GraphedEquation->GetGraphEquationOptions()->SetSelectedEquationLineWidth(static_cast(LineWidth + ((LineWidth <= 2) ? 1 : 2))); } } options.SetGraphColors(graphColors); } } - vector Grapher::GetGraphableEquations() + std::vector Grapher::GetGraphableEquations() { - vector validEqs; + std::vector validEqs; for (Equation ^ eq : Equations) { @@ -783,15 +761,10 @@ namespace GraphControl translationX /= -width; translationY /= height; - m_renderMain->GetCriticalSection().lock(); - if (FAILED(renderer->MoveRangeByRatio(translationX, translationY))) { - m_renderMain->GetCriticalSection().unlock(); return; } - - m_renderMain->GetCriticalSection().unlock(); needsRenderPass = true; } @@ -805,15 +778,10 @@ namespace GraphControl const auto& pos = e->Position; const auto [centerX, centerY] = PointerPositionToGraphPosition(pos.X, pos.Y, width, height); - m_renderMain->GetCriticalSection().lock(); - if (FAILED(renderer->ScaleRange(centerX, centerY, scale))) { - m_renderMain->GetCriticalSection().unlock(); return; } - - m_renderMain->GetCriticalSection().unlock(); needsRenderPass = true; } @@ -834,14 +802,14 @@ namespace GraphControl { if (auto renderer = m_graph->GetRenderer()) { - shared_ptr BitmapOut; + std::shared_ptr BitmapOut; bool hasSomeMissingDataOut = false; HRESULT hr = E_FAIL; hr = renderer->GetBitmap(BitmapOut, hasSomeMissingDataOut); if (SUCCEEDED(hr)) { // Get the raw data - vector byteVector = BitmapOut->GetData(); + std::vector byteVector = BitmapOut->GetData(); auto arr = ArrayReference(byteVector.data(), (unsigned int)byteVector.size()); // create a memory stream wrapper @@ -1085,13 +1053,13 @@ void Grapher::OnGraphBackgroundPropertyChanged(Windows::UI::Color /*oldValue*/, } } -void Grapher::OnGridLinesColorPropertyChanged(Windows::UI::Color /*oldValue*/, Windows::UI::Color newValue) +winrt::fire_and_forget Grapher::OnGridLinesColorPropertyChanged(Windows::UI::Color /*oldValue*/, Windows::UI::Color newValue) { if (m_renderMain != nullptr && m_graph != nullptr) { auto gridLinesColor = Graphing::Color(newValue.R, newValue.G, newValue.B, newValue.A); m_graph->GetOptions().SetGridColor(gridLinesColor); - m_renderMain->RunRenderPassAsync(); + co_await m_renderMain->RunRenderPassAsync(); } } @@ -1102,7 +1070,7 @@ void Grapher::OnLineWidthPropertyChanged(double oldValue, double newValue) UpdateGraphOptions(m_graph->GetOptions(), GetGraphableEquations()); if (m_renderMain) { - m_renderMain->SetPointRadius(LineWidth + 1); + m_renderMain->SetPointRadius(static_cast(LineWidth + 1)); m_renderMain->RunRenderPass(); TraceLogger::GetInstance()->LogLineWidthChanged(); @@ -1110,16 +1078,15 @@ void Grapher::OnLineWidthPropertyChanged(double oldValue, double newValue) } } -optional>> Grapher::TryInitializeGraph(bool keepCurrentView, const IExpression* graphingExp) +std::optional>> Grapher::TryInitializeGraph(bool keepCurrentView, const IExpression* graphingExp) { - critical_section::scoped_lock lock(m_renderMain->GetCriticalSection()); if (keepCurrentView || IsKeepCurrentView) { auto renderer = m_graph->GetRenderer(); double xMin, xMax, yMin, yMax; renderer->GetDisplayRanges(xMin, xMax, yMin, yMax); auto initResult = m_graph->TryInitialize(graphingExp); - if (initResult != nullopt) + if (initResult.has_value()) { if (IsKeepCurrentView) { diff --git a/src/GraphControl/Control/Grapher.h b/src/GraphControl/Control/Grapher.h index a4de85b4..26f306e3 100644 --- a/src/GraphControl/Control/Grapher.h +++ b/src/GraphControl/Control/Grapher.h @@ -228,7 +228,6 @@ public enum class GraphViewChangedReason { if (auto render = m_graph->GetRenderer()) { - Concurrency::critical_section::scoped_lock lock(m_renderMain->GetCriticalSection()); render->GetDisplayRanges(*xMin, *xMax, *yMin, *yMax); } } @@ -280,7 +279,7 @@ public enum class GraphViewChangedReason void OnEquationsPropertyChanged(EquationCollection ^ oldValue, EquationCollection ^ newValue); void OnAxesColorPropertyChanged(Windows::UI::Color oldValue, Windows::UI::Color newValue); void OnGraphBackgroundPropertyChanged(Windows::UI::Color oldValue, Windows::UI::Color newValue); - void OnGridLinesColorPropertyChanged(Windows::UI::Color /*oldValue*/, Windows::UI::Color newValue); + winrt::fire_and_forget OnGridLinesColorPropertyChanged(Windows::UI::Color /*oldValue*/, Windows::UI::Color newValue); void OnLineWidthPropertyChanged(double oldValue, double newValue); void OnEquationChanged(Equation ^ equation); void OnEquationStyleChanged(Equation ^ equation); diff --git a/src/GraphControl/DirectX/RenderMain.cpp b/src/GraphControl/DirectX/RenderMain.cpp index 1e53821b..fa5ddb77 100644 --- a/src/GraphControl/DirectX/RenderMain.cpp +++ b/src/GraphControl/DirectX/RenderMain.cpp @@ -30,11 +30,10 @@ namespace namespace GraphControl::DX { RenderMain::RenderMain(SwapChainPanel ^ panel) - : m_deviceResources{ panel } - , m_nearestPointRenderer{ &m_deviceResources } - , m_backgroundColor{ {} } - , m_swapChainPanel{ panel } - , m_TraceLocation(Point(0, 0)) + : m_deviceResources(panel) + , m_nearestPointRenderer(&m_deviceResources) + , m_swapChainPanel(panel) + , m_TraceLocation(Point{ 0, 0 }) , m_Tracing(false) { // Register to be notified if the Device is lost or recreated @@ -97,7 +96,7 @@ namespace GraphControl::DX bool wasPointRendered = m_Tracing; if (CanRenderPoint() || wasPointRendered) { - RunRenderPassAsync(); + [](RenderMain ^ self) -> winrt::fire_and_forget { co_await self->RunRenderPassAsync(); }(this); } } } @@ -111,7 +110,7 @@ namespace GraphControl::DX bool wasPointRendered = m_Tracing; if (CanRenderPoint() || wasPointRendered) { - RunRenderPassAsync(); + [](RenderMain ^ self) -> winrt::fire_and_forget { co_await self->RunRenderPassAsync(); }(this); } } } @@ -146,15 +145,6 @@ namespace GraphControl::DX trackPoint = m_activeTracingPointerLocation; } - if (!m_criticalSection.try_lock()) - { - return false; - } - - m_criticalSection.unlock(); - - critical_section::scoped_lock lock(m_criticalSection); - int formulaId = -1; double outNearestPointValueX, outNearestPointValueY; float outNearestPointLocationX, outNearestPointLocationY; @@ -210,67 +200,37 @@ namespace GraphControl::DX bool RenderMain::RunRenderPass() { - // Non async render passes cancel if they can't obtain the lock immediatly - if (!m_criticalSection.try_lock()) - { - return false; - } - - m_criticalSection.unlock(); - - critical_section::scoped_lock lock(m_criticalSection); - return RunRenderPassInternal(); } - IAsyncAction ^ RenderMain::RunRenderPassAsync(bool allowCancel) + concurrency::task RenderMain::RunRenderPassAsync(bool allowCancel) { - // Try to cancel the renderPass that is in progress - if (m_renderPass != nullptr && m_renderPass->Status == ::AsyncStatus::Started) - { - m_renderPass->Cancel(); - } - - auto device = m_deviceResources; - auto workItemHandler = ref new WorkItemHandler([this, allowCancel](IAsyncAction ^ action) { - critical_section::scoped_lock lock(m_criticalSection); - - // allowCancel is passed as false when the grapher relies on the render pass to validate that an equation can be succesfully rendered. - // Passing false garauntees that another render pass doesn't cancel this one. - if (allowCancel && action->Status == ::AsyncStatus::Canceled) - { - return; - } - - RunRenderPassInternal(); - }); - - m_renderPass = ThreadPool::RunAsync(workItemHandler, WorkItemPriority::High, WorkItemOptions::None); - - return m_renderPass; + bool result = false; + auto currentVer = ++m_renderPassVer; + Platform::WeakReference that{ this }; + co_await m_coreWindow->Dispatcher->RunAsync( + CoreDispatcherPriority::High, + ref new DispatchedHandler( + [&] + { + auto self = that.Resolve(); + if (self == nullptr || (allowCancel && m_renderPassVer != currentVer)) + { + return; + } + result = self->RunRenderPassInternal(); + })); + co_return result; } bool RenderMain::RunRenderPassInternal() { - // We are accessing Direct3D resources directly without Direct2D's knowledge, so we - // must manually acquire and apply the Direct2D factory lock. - ID2D1Multithread* m_D2DMultithread; - m_deviceResources.GetD2DFactory()->QueryInterface(IID_PPV_ARGS(&m_D2DMultithread)); - m_D2DMultithread->Enter(); - - bool succesful = Render(); - - if (succesful) + if (Render()) { m_deviceResources.Present(); + return true; } - - // It is absolutely critical that the factory lock be released upon - // exiting this function, or else any consequent Direct2D calls will be blocked. - m_D2DMultithread->Leave(); - - m_isRenderPassSuccesful = succesful; - return m_isRenderPassSuccesful; + return false; } // Renders the current frame according to the current application state. diff --git a/src/GraphControl/DirectX/RenderMain.h b/src/GraphControl/DirectX/RenderMain.h index be021ba5..6e4050aa 100644 --- a/src/GraphControl/DirectX/RenderMain.h +++ b/src/GraphControl/DirectX/RenderMain.h @@ -51,17 +51,7 @@ namespace GraphControl::DX bool RunRenderPass(); - Windows::Foundation::IAsyncAction ^ RunRenderPassAsync(bool allowCancel = true); - - Concurrency::critical_section& GetCriticalSection() - { - return m_criticalSection; - } - - bool IsRenderPassSuccesful() - { - return m_isRenderPassSuccesful; - } + concurrency::task RunRenderPassAsync(bool allowCancel = true); HRESULT GetRenderError(); @@ -185,10 +175,6 @@ namespace GraphControl::DX Windows::Foundation::EventRegistrationToken m_tokenOrientationChanged; Windows::Foundation::EventRegistrationToken m_tokenDisplayContentsInvalidated; - // Track our independent input on a background worker thread. - Windows::Foundation::IAsyncAction ^ m_inputLoopWorker = nullptr; - Windows::UI::Core::CoreIndependentInputSource ^ m_coreInput = nullptr; - double m_XTraceValue; double m_YTraceValue; @@ -198,11 +184,7 @@ namespace GraphControl::DX // Are we currently showing the tracing value bool m_Tracing; - Concurrency::critical_section m_criticalSection; - - Windows::Foundation::IAsyncAction ^ m_renderPass = nullptr; - - bool m_isRenderPassSuccesful; + unsigned m_renderPassVer = 0; HRESULT m_HResult; };