1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::future::Future;
15 use std::panic;
16 use std::ptr::NonNull;
17 use std::task::{Context, Poll, Waker};
18 
19 use crate::error::{ErrorKind, ScheduleError};
20 use crate::executor::Schedule;
21 use crate::task::raw::{Header, Inner, TaskMngInfo};
22 use crate::task::state;
23 use crate::task::state::StateAction;
24 use crate::task::waker::WakerRefHeader;
25 
26 cfg_not_ffrt! {
27     use crate::task::Task;
28 }
29 
30 pub(crate) struct TaskHandle<T: Future, S: Schedule> {
31     task: NonNull<TaskMngInfo<T, S>>,
32 }
33 
34 impl<T, S> TaskHandle<T, S>
35 where
36     T: Future,
37     S: Schedule,
38 {
39     pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Self {
40         TaskHandle {
41             task: ptr.cast::<TaskMngInfo<T, S>>(),
42         }
43     }
44 
headernull45     fn header(&self) -> &Header {
46         unsafe { self.task.as_ref().header() }
47     }
48 
innernull49     fn inner(&self) -> &Inner<T, S> {
50         unsafe { self.task.as_ref().inner() }
51     }
52 }
53 
54 impl<T, S> TaskHandle<T, S>
55 where
56     T: Future,
57     S: Schedule,
58 {
finishnull59     fn finish(self, state: usize, output: Result<T::Output, ScheduleError>) {
60         // send result if the JoinHandle is not dropped
61         if state::is_care_join_handle(state) {
62             self.inner().send_result(output);
63         } else {
64             self.inner().turning_to_used_data();
65         }
66 
67         let cur = match self.header().state.turning_to_finish() {
68             Ok(cur) => cur,
69             Err(e) => panic!("{}", e.as_str()),
70         };
71 
72         if state::is_set_waker(cur) {
73             self.inner().wake_join();
74         }
75         self.drop_ref();
76     }
77 
78     pub(crate) fn release(self) {
79         unsafe { drop(Box::from_raw(self.task.as_ptr())) };
80     }
81 
82     pub(crate) fn drop_ref(self) {
83         let prev = self.header().state.dec_ref();
84         if state::is_last_ref_count(prev) {
85             self.release();
86         }
87     }
88 
89     pub(crate) fn get_result(self, out: &mut Poll<std::result::Result<T::Output, ScheduleError>>) {
90         *out = Poll::Ready(self.inner().turning_to_get_data());
91     }
92 
93     pub(crate) fn drop_join_handle(self) {
94         if self.header().state.try_turning_to_un_join_handle() {
95             return;
96         }
97 
98         match self.header().state.turn_to_un_join_handle() {
99             Ok(_) => {}
100             Err(_) => {
101                 self.inner().turning_to_used_data();
102             }
103         }
104         self.drop_ref();
105     }
106 
set_waker_innernull107     fn set_waker_inner(&self, des_waker: Waker, cur_state: usize) -> Result<usize, usize> {
108         assert!(
109             state::is_care_join_handle(cur_state),
110             "set waker failed: the join handle has been dropped"
111         );
112         assert!(
113             !state::is_set_waker(cur_state),
114             "set waker failed: the task already has a waker set"
115         );
116 
117         unsafe {
118             let waker = self.inner().waker.get();
119             *waker = Some(des_waker);
120         }
121         let result = self.header().state.turn_to_set_waker();
122         if result.is_err() {
123             unsafe {
124                 let waker = self.inner().waker.get();
125                 *waker = None;
126             }
127         }
128         result
129     }
130 
131     pub(crate) fn set_waker(self, cur: usize, des_waker: &Waker) -> bool {
132         let res = if state::is_set_waker(cur) {
133             let is_same_waker = unsafe {
134                 // the status is set_waker, so waker must be set already
135                 let waker = self.inner().waker.get();
136                 (*waker)
137                     .as_ref()
138                     .expect("task status is set_waker, but waker is missing")
139                     .will_wake(des_waker)
140             };
141             // we don't register the same waker
142             if is_same_waker {
143                 return false;
144             }
145             self.header()
146                 .state
147                 .turn_to_un_set_waker()
148                 .and_then(|cur| self.set_waker_inner(des_waker.clone(), cur))
149         } else {
150             self.set_waker_inner(des_waker.clone(), cur)
151         };
152 
153         if let Err(cur) = res {
154             assert!(
155                 state::is_finished(cur),
156                 "setting waker should only be failed dur to task completion"
157             );
158             return true;
159         }
160 
161         false
162     }
163 }
164 
#[cfg(not(feature = "ffrt"))]
impl<T, S> TaskHandle<T, S>
where
    T: Future,
    S: Schedule,
{
    /// Runs one poll of the task on the current thread.
    ///
    /// The state is first moved to running. A canceled task is finished with a
    /// `TaskCanceled` error without being polled. Otherwise the future is polled
    /// once inside `catch_unwind`:
    /// - `Ready`: the task is finished with the output.
    /// - `Pending`: the state is moved back to idle. Depending on the returned action,
    ///   the task is re-scheduled with `lifo = true` or finished as canceled.
    /// - A panic: the task is finished with an `ErrorKind::Panic` error.
    ///
    /// # Panics
    ///
    /// Panics if a state transition reports `StateAction::Failed`, or if
    /// `turning_to_running` returns an action other than `Success`, `Canceled`, or
    /// `Failed`.
    pub(crate) fn run(self) {
        let action = self.header().state.turning_to_running();

        match action {
            StateAction::Success => {}
            StateAction::Canceled(cur) => {
                let output = self.get_canceled();
                return self.finish(cur, Err(output));
            }
            StateAction::Failed(state) => panic!("task state invalid: {state}"),
            _ => unreachable!(),
        };

        // turn the task header into a waker
        let waker = WakerRefHeader::<'_>::new::<T>(self.header());
        let mut context = Context::from_waker(&waker);

        // A panic inside `poll` is caught here and turned into an error result below.
        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            self.inner().poll(&mut context).map(Ok)
        }));

        let cur = self.header().state.get_current_state();
        match res {
            Ok(Poll::Ready(output)) => {
                // send result if the JoinHandle is not dropped
                self.finish(cur, output);
            }

            Ok(Poll::Pending) => match self.header().state.turning_to_idle() {
                StateAction::Enqueue => {
                    self.get_scheduled(true);
                }
                StateAction::Failed(state) => panic!("task state invalid: {state}"),
                StateAction::Canceled(state) => {
                    let output = self.get_canceled();
                    self.finish(state, Err(output));
                }
                _ => {}
            },

            Err(_) => {
                let output = Err(ScheduleError::new(ErrorKind::Panic, "panic happen"));
                self.finish(cur, output);
            }
        }
    }

    /// Gives up this handle's reference during shutdown, then either cancels the
    /// task or frees it.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not stated in this file. From the body,
    /// the caller must hold a reference to a live task, and the task may be freed by
    /// this call. Confirm and document.
    pub(crate) unsafe fn shutdown(self) {
        // If the JoinHandle still cares about the result, cancel the task and
        // schedule it (via `set_canceled`); otherwise free it.
        // NOTE(review): `drop_ref` treats `dec_ref()`'s return value as the state
        // *before* the decrement. If so, `ref_count(cur) > 0` is always true here and
        // only `is_care_join_handle` decides. Confirm this is intended.
        let cur = self.header().state.dec_ref();
        if state::ref_count(cur) > 0 && state::is_care_join_handle(cur) {
            self.set_canceled();
        } else {
            self.release();
        }
    }

    /// Wakes the task, then releases the reference held by this handle.
    pub(crate) fn wake(self) {
        self.wake_by_ref();
        self.drop_ref();
    }

    /// Wakes the task without releasing this handle's reference.
    ///
    /// Moves the state towards scheduled. If the previous state says the task needs
    /// enqueueing, it is submitted to the scheduler with `lifo = false`.
    pub(crate) fn wake_by_ref(&self) {
        let prev = self.header().state.turn_to_scheduling();
        if state::need_enqueue(prev) {
            self.get_scheduled(false);
        }
    }

    /// Handles the canceled case inside `run`.
    ///
    /// Calls `Inner::turning_to_used_data` and returns the `TaskCanceled` error to
    /// finish the task with.
    fn get_canceled(&self) -> ScheduleError {
        self.inner().turning_to_used_data();
        ErrorKind::TaskCanceled.into()
    }

    /// Marks the task as canceled and scheduled.
    ///
    /// If `turn_to_canceled_and_scheduled` returns `true`, the task is submitted to
    /// the scheduler (`lifo = false`). Presumably this lets the next `run` observe
    /// `StateAction::Canceled`.
    pub(crate) fn set_canceled(&self) {
        if self.header().state.turn_to_canceled_and_scheduled() {
            self.get_scheduled(false);
        }
    }

    /// Wraps this task's header pointer in a `Task` via `Task::from_raw`.
    ///
    /// NOTE(review): no reference-count change is visible here. Confirm that the
    /// ownership expectations of `Task::from_raw` match its callers in this file.
    fn to_task(&self) -> Task {
        unsafe { Task::from_raw(self.header().into()) }
    }

    /// Submits the task to its scheduler, with `lifo` passed through to `schedule`.
    ///
    /// # Panics
    ///
    /// Panics if `scheduler.upgrade()` fails (the scheduler has been dropped).
    fn get_scheduled(&self, lifo: bool) {
        // the scheduler must exist when calling this method
        self.inner()
            .scheduler
            .upgrade()
            .expect("the scheduler has already been dropped")
            .schedule(self.to_task(), lifo);
    }
}
268 
#[cfg(feature = "ffrt")]
impl<T, S> TaskHandle<T, S>
where
    T: Future,
    S: Schedule,
{
    /// Runs one poll of the task under the FFRT backend.
    ///
    /// Calls `Inner::get_task_ctx` first, then moves the state to running. A
    /// canceled task is finished with a `TaskCanceled` error without being polled.
    /// Otherwise the future is polled once inside `catch_unwind`.
    ///
    /// Returns `true` when the task was finished by this call (ready, canceled, or
    /// panicked), and `false` when it is still pending.
    ///
    /// # Panics
    ///
    /// Panics if a state transition reports `StateAction::Failed`.
    pub(crate) fn ffrt_run(self) -> bool {
        // NOTE(review): the role of `get_task_ctx` is defined in `Inner`;
        // presumably it binds the current FFRT task context. Confirm.
        self.inner().get_task_ctx();

        match self.header().state.turning_to_running() {
            StateAction::Failed(state) => panic!("turning to running failed: {:b}", state),
            StateAction::Canceled(cur) => {
                let output = self.ffrt_get_canceled();
                self.finish(cur, Err(output));
                return true;
            }
            _ => {}
        }

        // Build a waker that refers back to this task's header.
        let waker = WakerRefHeader::<'_>::new::<T>(self.header());
        let mut context = Context::from_waker(&waker);

        // A panic inside `poll` is caught here and turned into an error result below.
        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            self.inner().poll(&mut context).map(Ok)
        }));

        let cur = self.header().state.get_current_state();
        match res {
            Ok(Poll::Ready(output)) => {
                // send result if the JoinHandle is not dropped
                self.finish(cur, output);
                true
            }

            Ok(Poll::Pending) => match self.header().state.turning_to_idle() {
                StateAction::Enqueue => {
                    // The idle transition asked for re-enqueueing; hand that to FFRT.
                    self.wake_ffrt_task();
                    false
                }
                StateAction::Failed(state) => panic!("task state invalid: {:b}", state),
                StateAction::Canceled(state) => {
                    let output = self.ffrt_get_canceled();
                    self.finish(state, Err(output));
                    true
                }
                _ => false,
            },

            Err(_) => {
                let output = Err(ScheduleError::new(ErrorKind::Panic, "panic happen"));
                self.finish(cur, output);
                true
            }
        }
    }

    /// Wakes the task, then releases the reference held by this handle.
    pub(crate) fn ffrt_wake(self) {
        self.ffrt_wake_by_ref();
        self.drop_ref();
    }

    /// Wakes the task without releasing this handle's reference.
    ///
    /// Moves the state towards scheduled. If the previous state says the task needs
    /// enqueueing, the underlying FFRT task is woken.
    pub(crate) fn ffrt_wake_by_ref(&self) {
        let prev = self.header().state.turn_to_scheduling();
        if state::need_enqueue(prev) {
            self.wake_ffrt_task();
        }
    }

    /// Handles the canceled case inside `ffrt_run`.
    ///
    /// Calls `Inner::turning_to_used_data` and returns the `TaskCanceled` error to
    /// finish the task with.
    fn ffrt_get_canceled(&self) -> ScheduleError {
        self.inner().turning_to_used_data();
        ErrorKind::TaskCanceled.into()
    }

    /// Marks the task as canceled and scheduled.
    ///
    /// If `turn_to_canceled_and_scheduled` returns `true`, the underlying FFRT task
    /// is woken.
    pub(crate) fn ffrt_set_canceled(&self) {
        if self.header().state.turn_to_canceled_and_scheduled() {
            self.wake_ffrt_task();
        }
    }

    /// Calls `wake_task` on the FFRT task stored in `Inner::task`.
    ///
    /// Shared by `ffrt_run`, `ffrt_wake_by_ref` and `ffrt_set_canceled`, which
    /// previously each repeated this unsafe access inline.
    ///
    /// # Panics
    ///
    /// Panics if no FFRT task is stored in the slot.
    fn wake_ffrt_task(&self) {
        // SAFETY: NOTE(review) assumes nothing mutates the task slot concurrently
        // while it is read here (same assumption as the original inline accesses).
        // Confirm.
        let ffrt_task = unsafe { (*self.inner().task.get()).as_ref().unwrap() };
        ffrt_task.wake_task();
    }
}
353