// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14use std::future::Future;
15use std::panic;
16use std::ptr::NonNull;
17use std::task::{Context, Poll, Waker};
18
19use crate::error::{ErrorKind, ScheduleError};
20use crate::executor::Schedule;
21use crate::task::raw::{Header, Inner, TaskMngInfo};
22use crate::task::state;
23use crate::task::state::StateAction;
24use crate::task::waker::WakerRefHeader;
25
26cfg_not_ffrt! {
27    use crate::task::Task;
28}
29
30pub(crate) struct TaskHandle<T: Future, S: Schedule> {
31    task: NonNull<TaskMngInfo<T, S>>,
32}
33
34impl<T, S> TaskHandle<T, S>
35where
36    T: Future,
37    S: Schedule,
38{
39    pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Self {
40        TaskHandle {
41            task: ptr.cast::<TaskMngInfo<T, S>>(),
42        }
43    }
44
45    fn header(&self) -> &Header {
46        unsafe { self.task.as_ref().header() }
47    }
48
49    fn inner(&self) -> &Inner<T, S> {
50        unsafe { self.task.as_ref().inner() }
51    }
52}
53
54impl<T, S> TaskHandle<T, S>
55where
56    T: Future,
57    S: Schedule,
58{
59    fn finish(self, state: usize, output: Result<T::Output, ScheduleError>) {
60        // send result if the JoinHandle is not dropped
61        if state::is_care_join_handle(state) {
62            self.inner().send_result(output);
63        } else {
64            self.inner().turning_to_used_data();
65        }
66
67        let cur = match self.header().state.turning_to_finish() {
68            Ok(cur) => cur,
69            Err(e) => panic!("{}", e.as_str()),
70        };
71
72        if state::is_set_waker(cur) {
73            self.inner().wake_join();
74        }
75        self.drop_ref();
76    }
77
78    pub(crate) fn release(self) {
79        unsafe { drop(Box::from_raw(self.task.as_ptr())) };
80    }
81
82    pub(crate) fn drop_ref(self) {
83        let prev = self.header().state.dec_ref();
84        if state::is_last_ref_count(prev) {
85            self.release();
86        }
87    }
88
89    pub(crate) fn get_result(self, out: &mut Poll<std::result::Result<T::Output, ScheduleError>>) {
90        *out = Poll::Ready(self.inner().turning_to_get_data());
91    }
92
93    pub(crate) fn drop_join_handle(self) {
94        if self.header().state.try_turning_to_un_join_handle() {
95            return;
96        }
97
98        match self.header().state.turn_to_un_join_handle() {
99            Ok(_) => {}
100            Err(_) => {
101                self.inner().turning_to_used_data();
102            }
103        }
104        self.drop_ref();
105    }
106
107    fn set_waker_inner(&self, des_waker: Waker, cur_state: usize) -> Result<usize, usize> {
108        assert!(
109            state::is_care_join_handle(cur_state),
110            "set waker failed: the join handle has been dropped"
111        );
112        assert!(
113            !state::is_set_waker(cur_state),
114            "set waker failed: the task already has a waker set"
115        );
116
117        unsafe {
118            let waker = self.inner().waker.get();
119            *waker = Some(des_waker);
120        }
121        let result = self.header().state.turn_to_set_waker();
122        if result.is_err() {
123            unsafe {
124                let waker = self.inner().waker.get();
125                *waker = None;
126            }
127        }
128        result
129    }
130
131    pub(crate) fn set_waker(self, cur: usize, des_waker: &Waker) -> bool {
132        let res = if state::is_set_waker(cur) {
133            let is_same_waker = unsafe {
134                // the status is set_waker, so waker must be set already
135                let waker = self.inner().waker.get();
136                (*waker)
137                    .as_ref()
138                    .expect("task status is set_waker, but waker is missing")
139                    .will_wake(des_waker)
140            };
141            // we don't register the same waker
142            if is_same_waker {
143                return false;
144            }
145            self.header()
146                .state
147                .turn_to_un_set_waker()
148                .and_then(|cur| self.set_waker_inner(des_waker.clone(), cur))
149        } else {
150            self.set_waker_inner(des_waker.clone(), cur)
151        };
152
153        if let Err(cur) = res {
154            assert!(
155                state::is_finished(cur),
156                "setting waker should only be failed dur to task completion"
157            );
158            return true;
159        }
160
161        false
162    }
163}
164
165#[cfg(not(feature = "ffrt"))]
166impl<T, S> TaskHandle<T, S>
167where
168    T: Future,
169    S: Schedule,
170{
171    // Runs the task
172    pub(crate) fn run(self) {
173        let action = self.header().state.turning_to_running();
174
175        match action {
176            StateAction::Success => {}
177            StateAction::Canceled(cur) => {
178                let output = self.get_canceled();
179                return self.finish(cur, Err(output));
180            }
181            StateAction::Failed(state) => panic!("task state invalid: {state}"),
182            _ => unreachable!(),
183        };
184
185        // turn the task header into a waker
186        let waker = WakerRefHeader::<'_>::new::<T>(self.header());
187        let mut context = Context::from_waker(&waker);
188
189        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
190            self.inner().poll(&mut context).map(Ok)
191        }));
192
193        let cur = self.header().state.get_current_state();
194        match res {
195            Ok(Poll::Ready(output)) => {
196                // send result if the JoinHandle is not dropped
197                self.finish(cur, output);
198            }
199
200            Ok(Poll::Pending) => match self.header().state.turning_to_idle() {
201                StateAction::Enqueue => {
202                    self.get_scheduled(true);
203                }
204                StateAction::Failed(state) => panic!("task state invalid: {state}"),
205                StateAction::Canceled(state) => {
206                    let output = self.get_canceled();
207                    self.finish(state, Err(output));
208                }
209                _ => {}
210            },
211
212            Err(_) => {
213                let output = Err(ScheduleError::new(ErrorKind::Panic, "panic happen"));
214                self.finish(cur, output);
215            }
216        }
217    }
218
219    pub(crate) unsafe fn shutdown(self) {
220        // Check if the JoinHandle gets dropped already. If JoinHandle is still there,
221        // wakes the JoinHandle.
222        let cur = self.header().state.dec_ref();
223        if state::ref_count(cur) > 0 && state::is_care_join_handle(cur) {
224            self.set_canceled();
225        } else {
226            self.release();
227        }
228    }
229
230    pub(crate) fn wake(self) {
231        self.wake_by_ref();
232        self.drop_ref();
233    }
234
235    pub(crate) fn wake_by_ref(&self) {
236        let prev = self.header().state.turn_to_scheduling();
237        if state::need_enqueue(prev) {
238            self.get_scheduled(false);
239        }
240    }
241
242    // Actually cancels the task during running
243    fn get_canceled(&self) -> ScheduleError {
244        self.inner().turning_to_used_data();
245        ErrorKind::TaskCanceled.into()
246    }
247
248    // Sets task state into canceled and scheduled
249    pub(crate) fn set_canceled(&self) {
250        if self.header().state.turn_to_canceled_and_scheduled() {
251            self.get_scheduled(false);
252        }
253    }
254
255    fn to_task(&self) -> Task {
256        unsafe { Task::from_raw(self.header().into()) }
257    }
258
259    fn get_scheduled(&self, lifo: bool) {
260        // the scheduler must exist when calling this method
261        self.inner()
262            .scheduler
263            .upgrade()
264            .expect("the scheduler has already been dropped")
265            .schedule(self.to_task(), lifo);
266    }
267}
268
#[cfg(feature = "ffrt")]
impl<T: Future, S: Schedule> TaskHandle<T, S> {
    /// Drives the task through a single poll under the `ffrt` feature.
    ///
    /// Returns `true` on every path that ends in `finish` (output ready,
    /// cancellation, or a panic during the poll) and `false` when the task
    /// stays pending.
    ///
    /// # Panics
    ///
    /// Panics when a state transition reports `StateAction::Failed`, or when
    /// the task context slot is empty on the re-enqueue path.
    pub(crate) fn ffrt_run(self) -> bool {
        self.inner().get_task_ctx();

        match self.header().state.turning_to_running() {
            StateAction::Canceled(cur) => {
                let canceled = self.ffrt_get_canceled();
                self.finish(cur, Err(canceled));
                return true;
            }
            StateAction::Failed(state) => panic!("turning to running failed: {:b}", state),
            _ => {}
        }

        // The header itself serves as the waker for this poll.
        let waker = WakerRefHeader::<'_>::new::<T>(self.header());
        let mut cx = Context::from_waker(&waker);

        let polled = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            self.inner().poll(&mut cx).map(Ok)
        }));

        // Snapshot the state right after polling; `finish` uses it to decide
        // whether the join handle still wants the result.
        let cur = self.header().state.get_current_state();

        let poll = match polled {
            Ok(poll) => poll,
            Err(_) => {
                // The future panicked while being polled.
                let failure = ScheduleError::new(ErrorKind::Panic, "panic happen");
                self.finish(cur, Err(failure));
                return true;
            }
        };

        match poll {
            Poll::Ready(output) => {
                // send result if the JoinHandle is not dropped
                self.finish(cur, output);
                true
            }
            Poll::Pending => match self.header().state.turning_to_idle() {
                StateAction::Enqueue => {
                    // SAFETY: assumes the task context slot is not mutated
                    // concurrently while it is read here — TODO confirm.
                    let slot = unsafe { &*self.inner().task.get() };
                    slot.as_ref().unwrap().wake_task();
                    false
                }
                StateAction::Canceled(state) => {
                    let canceled = self.ffrt_get_canceled();
                    self.finish(state, Err(canceled));
                    true
                }
                StateAction::Failed(state) => panic!("task state invalid: {:b}", state),
                _ => false,
            },
        }
    }

    /// Wakes the task, then gives up the reference held by this handle.
    pub(crate) fn ffrt_wake(self) {
        self.ffrt_wake_by_ref();
        self.drop_ref();
    }

    /// Switches the state to scheduling and, when `need_enqueue` says so for
    /// the value returned by the transition, calls `wake_task` on the stored
    /// task context.
    ///
    /// # Panics
    ///
    /// Panics if the task context slot is empty when a wake is required.
    pub(crate) fn ffrt_wake_by_ref(&self) {
        if state::need_enqueue(self.header().state.turn_to_scheduling()) {
            // SAFETY: assumes the task context slot is not mutated
            // concurrently while it is read here — TODO confirm.
            let slot = unsafe { &*self.inner().task.get() };
            slot.as_ref().unwrap().wake_task();
        }
    }

    /// Performs the cancellation observed while running: switches the stored
    /// data through `turning_to_used_data` and yields a `TaskCanceled` error.
    fn ffrt_get_canceled(&self) -> ScheduleError {
        self.inner().turning_to_used_data();
        ErrorKind::TaskCanceled.into()
    }

    /// Marks the task as canceled and scheduled; if the transition asks for
    /// it, calls `wake_task` on the stored task context.
    ///
    /// # Panics
    ///
    /// Panics if the task context slot is empty when a wake is required.
    pub(crate) fn ffrt_set_canceled(&self) {
        let should_wake = self.header().state.turn_to_canceled_and_scheduled();
        if should_wake {
            // SAFETY: assumes the task context slot is not mutated
            // concurrently while it is read here — TODO confirm.
            let slot = unsafe { &*self.inner().task.get() };
            slot.as_ref().unwrap().wake_task();
        }
    }
}
353