1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * AMD Secure Encrypted Virtualization (SEV) guest driver interface
4 *
5 * Copyright (C) 2021 Advanced Micro Devices, Inc.
6 *
7 * Author: Brijesh Singh <brijesh.singh@amd.com>
8 */
9
10#include <linux/module.h>
11#include <linux/kernel.h>
12#include <linux/types.h>
13#include <linux/mutex.h>
14#include <linux/io.h>
15#include <linux/platform_device.h>
16#include <linux/miscdevice.h>
17#include <linux/set_memory.h>
18#include <linux/fs.h>
19#include <crypto/aead.h>
20#include <linux/scatterlist.h>
21#include <linux/psp-sev.h>
22#include <uapi/linux/sev-guest.h>
23#include <uapi/linux/psp-sev.h>
24
25#include <asm/svm.h>
26#include <asm/sev.h>
27
28#include "sev-guest.h"
29
/* Name given to the misc device registered by the probe routine. */
#define DEVICE_NAME	"sev-guest"
/* Number of message header bytes, starting at hdr->algo, passed to the AEAD as associated data. */
#define AAD_LEN		48
/* Value written into hdr_version of every request header. */
#define MSG_HDR_VER	1

/* Upper bound on how long a throttled (-EAGAIN) request keeps being retried. */
#define SNP_REQ_MAX_RETRY_DURATION	(60*HZ)
/* Sleep between two retries of a throttled request. */
#define SNP_REQ_RETRY_DELAY		(2*HZ)
36
37struct snp_guest_crypto {
38	struct crypto_aead *tfm;
39	u8 *iv, *authtag;
40	int iv_len, a_len;
41};
42
43struct snp_guest_dev {
44	struct device *dev;
45	struct miscdevice misc;
46
47	void *certs_data;
48	struct snp_guest_crypto *crypto;
49	/* request and response are in unencrypted memory */
50	struct snp_guest_msg *request, *response;
51
52	/*
53	 * Avoid information leakage by double-buffering shared messages
54	 * in fields that are in regular encrypted memory.
55	 */
56	struct snp_guest_msg secret_request, secret_response;
57
58	struct snp_secrets_page_layout *layout;
59	struct snp_req_data input;
60	union {
61		struct snp_report_req report;
62		struct snp_derived_key_req derived_key;
63		struct snp_ext_report_req ext_report;
64	} req;
65	u32 *os_area_msg_seqno;
66	u8 *vmpck;
67};
68
/*
 * Index of the VMPCK used for all guest messages. Exposed as a module
 * parameter with mode 0444, i.e. readable but not writable through sysfs.
 */
static u32 vmpck_id;
module_param(vmpck_id, uint, 0444);
MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.");

/* Mutex to serialize the shared buffer access and command handling. */
static DEFINE_MUTEX(snp_cmd_mutex);
75
76static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
77{
78	char zero_key[VMPCK_KEY_LEN] = {0};
79
80	if (snp_dev->vmpck)
81		return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
82
83	return true;
84}
85
86/*
87 * If an error is received from the host or AMD Secure Processor (ASP) there
88 * are two options. Either retry the exact same encrypted request or discontinue
89 * using the VMPCK.
90 *
91 * This is because in the current encryption scheme GHCB v2 uses AES-GCM to
92 * encrypt the requests. The IV for this scheme is the sequence number. GCM
93 * cannot tolerate IV reuse.
94 *
95 * The ASP FW v1.51 only increments the sequence numbers on a successful
96 * guest<->ASP back and forth and only accepts messages at its exact sequence
97 * number.
98 *
99 * So if the sequence number were to be reused the encryption scheme is
100 * vulnerable. If the sequence number were incremented for a fresh IV the ASP
101 * will reject the request.
102 */
103static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
104{
105	dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
106		  vmpck_id);
107	memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
108	snp_dev->vmpck = NULL;
109}
110
111static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
112{
113	u64 count;
114
115	lockdep_assert_held(&snp_cmd_mutex);
116
117	/* Read the current message sequence counter from secrets pages */
118	count = *snp_dev->os_area_msg_seqno;
119
120	return count + 1;
121}
122
123/* Return a non-zero on success */
124static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
125{
126	u64 count = __snp_get_msg_seqno(snp_dev);
127
128	/*
129	 * The message sequence counter for the SNP guest request is a  64-bit
130	 * value but the version 2 of GHCB specification defines a 32-bit storage
131	 * for it. If the counter exceeds the 32-bit value then return zero.
132	 * The caller should check the return value, but if the caller happens to
133	 * not check the value and use it, then the firmware treats zero as an
134	 * invalid number and will fail the  message request.
135	 */
136	if (count >= UINT_MAX) {
137		dev_err(snp_dev->dev, "request message sequence counter overflow\n");
138		return 0;
139	}
140
141	return count;
142}
143
144static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
145{
146	/*
147	 * The counter is also incremented by the PSP, so increment it by 2
148	 * and save in secrets page.
149	 */
150	*snp_dev->os_area_msg_seqno += 2;
151}
152
153static inline struct snp_guest_dev *to_snp_dev(struct file *file)
154{
155	struct miscdevice *dev = file->private_data;
156
157	return container_of(dev, struct snp_guest_dev, misc);
158}
159
160static struct snp_guest_crypto *init_crypto(struct snp_guest_dev *snp_dev, u8 *key, size_t keylen)
161{
162	struct snp_guest_crypto *crypto;
163
164	crypto = kzalloc(sizeof(*crypto), GFP_KERNEL_ACCOUNT);
165	if (!crypto)
166		return NULL;
167
168	crypto->tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
169	if (IS_ERR(crypto->tfm))
170		goto e_free;
171
172	if (crypto_aead_setkey(crypto->tfm, key, keylen))
173		goto e_free_crypto;
174
175	crypto->iv_len = crypto_aead_ivsize(crypto->tfm);
176	crypto->iv = kmalloc(crypto->iv_len, GFP_KERNEL_ACCOUNT);
177	if (!crypto->iv)
178		goto e_free_crypto;
179
180	if (crypto_aead_authsize(crypto->tfm) > MAX_AUTHTAG_LEN) {
181		if (crypto_aead_setauthsize(crypto->tfm, MAX_AUTHTAG_LEN)) {
182			dev_err(snp_dev->dev, "failed to set authsize to %d\n", MAX_AUTHTAG_LEN);
183			goto e_free_iv;
184		}
185	}
186
187	crypto->a_len = crypto_aead_authsize(crypto->tfm);
188	crypto->authtag = kmalloc(crypto->a_len, GFP_KERNEL_ACCOUNT);
189	if (!crypto->authtag)
190		goto e_free_iv;
191
192	return crypto;
193
194e_free_iv:
195	kfree(crypto->iv);
196e_free_crypto:
197	crypto_free_aead(crypto->tfm);
198e_free:
199	kfree(crypto);
200
201	return NULL;
202}
203
204static void deinit_crypto(struct snp_guest_crypto *crypto)
205{
206	crypto_free_aead(crypto->tfm);
207	kfree(crypto->iv);
208	kfree(crypto->authtag);
209	kfree(crypto);
210}
211
212static int enc_dec_message(struct snp_guest_crypto *crypto, struct snp_guest_msg *msg,
213			   u8 *src_buf, u8 *dst_buf, size_t len, bool enc)
214{
215	struct snp_guest_msg_hdr *hdr = &msg->hdr;
216	struct scatterlist src[3], dst[3];
217	DECLARE_CRYPTO_WAIT(wait);
218	struct aead_request *req;
219	int ret;
220
221	req = aead_request_alloc(crypto->tfm, GFP_KERNEL);
222	if (!req)
223		return -ENOMEM;
224
225	/*
226	 * AEAD memory operations:
227	 * +------ AAD -------+------- DATA -----+---- AUTHTAG----+
228	 * |  msg header      |  plaintext       |  hdr->authtag  |
229	 * | bytes 30h - 5Fh  |    or            |                |
230	 * |                  |   cipher         |                |
231	 * +------------------+------------------+----------------+
232	 */
233	sg_init_table(src, 3);
234	sg_set_buf(&src[0], &hdr->algo, AAD_LEN);
235	sg_set_buf(&src[1], src_buf, hdr->msg_sz);
236	sg_set_buf(&src[2], hdr->authtag, crypto->a_len);
237
238	sg_init_table(dst, 3);
239	sg_set_buf(&dst[0], &hdr->algo, AAD_LEN);
240	sg_set_buf(&dst[1], dst_buf, hdr->msg_sz);
241	sg_set_buf(&dst[2], hdr->authtag, crypto->a_len);
242
243	aead_request_set_ad(req, AAD_LEN);
244	aead_request_set_tfm(req, crypto->tfm);
245	aead_request_set_callback(req, 0, crypto_req_done, &wait);
246
247	aead_request_set_crypt(req, src, dst, len, crypto->iv);
248	ret = crypto_wait_req(enc ? crypto_aead_encrypt(req) : crypto_aead_decrypt(req), &wait);
249
250	aead_request_free(req);
251	return ret;
252}
253
254static int __enc_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
255			 void *plaintext, size_t len)
256{
257	struct snp_guest_crypto *crypto = snp_dev->crypto;
258	struct snp_guest_msg_hdr *hdr = &msg->hdr;
259
260	memset(crypto->iv, 0, crypto->iv_len);
261	memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
262
263	return enc_dec_message(crypto, msg, plaintext, msg->payload, len, true);
264}
265
266static int dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
267		       void *plaintext, size_t len)
268{
269	struct snp_guest_crypto *crypto = snp_dev->crypto;
270	struct snp_guest_msg_hdr *hdr = &msg->hdr;
271
272	/* Build IV with response buffer sequence number */
273	memset(crypto->iv, 0, crypto->iv_len);
274	memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
275
276	return enc_dec_message(crypto, msg, msg->payload, plaintext, len, false);
277}
278
279static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
280{
281	struct snp_guest_crypto *crypto = snp_dev->crypto;
282	struct snp_guest_msg *resp = &snp_dev->secret_response;
283	struct snp_guest_msg *req = &snp_dev->secret_request;
284	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
285	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
286
287	dev_dbg(snp_dev->dev, "response [seqno %lld type %d version %d sz %d]\n",
288		resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version, resp_hdr->msg_sz);
289
290	/* Copy response from shared memory to encrypted memory. */
291	memcpy(resp, snp_dev->response, sizeof(*resp));
292
293	/* Verify that the sequence counter is incremented by 1 */
294	if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
295		return -EBADMSG;
296
297	/* Verify response message type and version number. */
298	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
299	    resp_hdr->msg_version != req_hdr->msg_version)
300		return -EBADMSG;
301
302	/*
303	 * If the message size is greater than our buffer length then return
304	 * an error.
305	 */
306	if (unlikely((resp_hdr->msg_sz + crypto->a_len) > sz))
307		return -EBADMSG;
308
309	/* Decrypt the payload */
310	return dec_payload(snp_dev, resp, payload, resp_hdr->msg_sz + crypto->a_len);
311}
312
313static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
314			void *payload, size_t sz)
315{
316	struct snp_guest_msg *req = &snp_dev->secret_request;
317	struct snp_guest_msg_hdr *hdr = &req->hdr;
318
319	memset(req, 0, sizeof(*req));
320
321	hdr->algo = SNP_AEAD_AES_256_GCM;
322	hdr->hdr_version = MSG_HDR_VER;
323	hdr->hdr_sz = sizeof(*hdr);
324	hdr->msg_type = type;
325	hdr->msg_version = version;
326	hdr->msg_seqno = seqno;
327	hdr->msg_vmpck = vmpck_id;
328	hdr->msg_sz = sz;
329
330	/* Verify the sequence number is non-zero */
331	if (!hdr->msg_seqno)
332		return -ENOSR;
333
334	dev_dbg(snp_dev->dev, "request [seqno %lld type %d version %d sz %d]\n",
335		hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
336
337	return __enc_payload(snp_dev, req, payload, sz);
338}
339
/*
 * Issue the request described by snp_dev->input and retry it where needed.
 *
 * @exit_code: exit code passed to snp_issue_guest_request(); replaced by
 *             SVM_VMGEXIT_GUEST_REQUEST when a retry follows -ENOSPC
 * @rio:       ioctl argument; rio->exitinfo2 is overwritten with the
 *             INVALID_LEN VMM error when such a retry took place
 *
 * The same request is re-issued when the call returns -ENOSPC or -EAGAIN;
 * -EAGAIN retries stop with -ETIMEDOUT once SNP_REQ_MAX_RETRY_DURATION has
 * elapsed. The message sequence number is incremented exactly once after
 * the retry loop, whatever the outcome.
 *
 * Returns the result of the last snp_issue_guest_request() call, except
 * that a successful retry after -ENOSPC is reported as -EIO and a timed out
 * -EAGAIN loop as -ETIMEDOUT.
 */
static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
				  struct snp_guest_request_ioctl *rio)
{
	unsigned long req_start = jiffies;
	unsigned int override_npages = 0;
	u64 override_err = 0;
	int rc;

retry_request:
	/*
	 * Call firmware to process the request. In this function the encrypted
	 * message enters shared memory with the host. So after this call the
	 * sequence number must be incremented or the VMPCK must be deleted to
	 * prevent reuse of the IV.
	 */
	rc = snp_issue_guest_request(exit_code, &snp_dev->input, rio);
	switch (rc) {
	case -ENOSPC:
		/*
		 * If the extended guest request fails due to having too
		 * small of a certificate data buffer, retry the same
		 * guest request without the extended data request in
		 * order to increment the sequence number and thus avoid
		 * IV reuse.
		 */
		override_npages = snp_dev->input.data_npages;
		exit_code	= SVM_VMGEXIT_GUEST_REQUEST;

		/*
		 * Override the error to inform callers the given extended
		 * request buffer size was too small and give the caller the
		 * required buffer size.
		 */
		override_err = SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN);

		/*
		 * If this call to the firmware succeeds, the sequence number can
		 * be incremented allowing for continued use of the VMPCK. If
		 * there is an error reflected in the return value, this value
		 * is checked further down and the result will be the deletion
		 * of the VMPCK and the error code being propagated back to the
		 * user as an ioctl() return code.
		 */
		goto retry_request;

	/*
	 * The host may return SNP_GUEST_VMM_ERR_BUSY if the request has been
	 * throttled. Retry in the driver to avoid returning and reusing the
	 * message sequence number on a different message.
	 */
	case -EAGAIN:
		if (jiffies - req_start > SNP_REQ_MAX_RETRY_DURATION) {
			rc = -ETIMEDOUT;
			break;
		}
		schedule_timeout_killable(SNP_REQ_RETRY_DELAY);
		goto retry_request;
	}

	/*
	 * Increment the message sequence number. There is no harm in doing
	 * this now because decryption uses the value stored in the response
	 * structure and any failure will wipe the VMPCK, preventing further
	 * use anyway.
	 */
	snp_inc_msg_seqno(snp_dev);

	if (override_err) {
		rio->exitinfo2 = override_err;

		/*
		 * If an extended guest request was issued and the supplied certificate
		 * buffer was not large enough, a standard guest request was issued to
		 * prevent IV reuse. If the standard request was successful, return -EIO
		 * back to the caller as would have originally been returned.
		 */
		if (!rc && override_err == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
			rc = -EIO;
	}

	/*
	 * Put back the page count captured right after the -ENOSPC attempt so
	 * the caller can read it from snp_dev->input (presumably the retried
	 * request may otherwise leave a different value there - TODO confirm).
	 */
	if (override_npages)
		snp_dev->input.data_npages = override_npages;

	return rc;
}
425
426static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
427				struct snp_guest_request_ioctl *rio, u8 type,
428				void *req_buf, size_t req_sz, void *resp_buf,
429				u32 resp_sz)
430{
431	u64 seqno;
432	int rc;
433
434	/* Get message sequence and verify that its a non-zero */
435	seqno = snp_get_msg_seqno(snp_dev);
436	if (!seqno)
437		return -EIO;
438
439	/* Clear shared memory's response for the host to populate. */
440	memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
441
442	/* Encrypt the userspace provided payload in snp_dev->secret_request. */
443	rc = enc_payload(snp_dev, seqno, rio->msg_version, type, req_buf, req_sz);
444	if (rc)
445		return rc;
446
447	/*
448	 * Write the fully encrypted request to the shared unencrypted
449	 * request page.
450	 */
451	memcpy(snp_dev->request, &snp_dev->secret_request,
452	       sizeof(snp_dev->secret_request));
453
454	rc = __handle_guest_request(snp_dev, exit_code, rio);
455	if (rc) {
456		if (rc == -EIO &&
457		    rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
458			return rc;
459
460		dev_alert(snp_dev->dev,
461			  "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
462			  rc, rio->exitinfo2);
463
464		snp_disable_vmpck(snp_dev);
465		return rc;
466	}
467
468	rc = verify_and_dec_payload(snp_dev, resp_buf, resp_sz);
469	if (rc) {
470		dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc);
471		snp_disable_vmpck(snp_dev);
472		return rc;
473	}
474
475	return 0;
476}
477
478static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
479{
480	struct snp_guest_crypto *crypto = snp_dev->crypto;
481	struct snp_report_req *req = &snp_dev->req.report;
482	struct snp_report_resp *resp;
483	int rc, resp_len;
484
485	lockdep_assert_held(&snp_cmd_mutex);
486
487	if (!arg->req_data || !arg->resp_data)
488		return -EINVAL;
489
490	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
491		return -EFAULT;
492
493	/*
494	 * The intermediate response buffer is used while decrypting the
495	 * response payload. Make sure that it has enough space to cover the
496	 * authtag.
497	 */
498	resp_len = sizeof(resp->data) + crypto->a_len;
499	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
500	if (!resp)
501		return -ENOMEM;
502
503	rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
504				  SNP_MSG_REPORT_REQ, req, sizeof(*req), resp->data,
505				  resp_len);
506	if (rc)
507		goto e_free;
508
509	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
510		rc = -EFAULT;
511
512e_free:
513	kfree(resp);
514	return rc;
515}
516
517static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
518{
519	struct snp_derived_key_req *req = &snp_dev->req.derived_key;
520	struct snp_guest_crypto *crypto = snp_dev->crypto;
521	struct snp_derived_key_resp resp = {0};
522	int rc, resp_len;
523	/* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
524	u8 buf[64 + 16];
525
526	lockdep_assert_held(&snp_cmd_mutex);
527
528	if (!arg->req_data || !arg->resp_data)
529		return -EINVAL;
530
531	/*
532	 * The intermediate response buffer is used while decrypting the
533	 * response payload. Make sure that it has enough space to cover the
534	 * authtag.
535	 */
536	resp_len = sizeof(resp.data) + crypto->a_len;
537	if (sizeof(buf) < resp_len)
538		return -ENOMEM;
539
540	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
541		return -EFAULT;
542
543	rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
544				  SNP_MSG_KEY_REQ, req, sizeof(*req), buf, resp_len);
545	if (rc)
546		return rc;
547
548	memcpy(resp.data, buf, sizeof(resp.data));
549	if (copy_to_user((void __user *)arg->resp_data, &resp, sizeof(resp)))
550		rc = -EFAULT;
551
552	/* The response buffer contains the sensitive data, explicitly clear it. */
553	memzero_explicit(buf, sizeof(buf));
554	memzero_explicit(&resp, sizeof(resp));
555	return rc;
556}
557
558static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
559{
560	struct snp_ext_report_req *req = &snp_dev->req.ext_report;
561	struct snp_guest_crypto *crypto = snp_dev->crypto;
562	struct snp_report_resp *resp;
563	int ret, npages = 0, resp_len;
564
565	lockdep_assert_held(&snp_cmd_mutex);
566
567	if (!arg->req_data || !arg->resp_data)
568		return -EINVAL;
569
570	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
571		return -EFAULT;
572
573	/* userspace does not want certificate data */
574	if (!req->certs_len || !req->certs_address)
575		goto cmd;
576
577	if (req->certs_len > SEV_FW_BLOB_MAX_SIZE ||
578	    !IS_ALIGNED(req->certs_len, PAGE_SIZE))
579		return -EINVAL;
580
581	if (!access_ok((const void __user *)req->certs_address, req->certs_len))
582		return -EFAULT;
583
584	/*
585	 * Initialize the intermediate buffer with all zeros. This buffer
586	 * is used in the guest request message to get the certs blob from
587	 * the host. If host does not supply any certs in it, then copy
588	 * zeros to indicate that certificate data was not provided.
589	 */
590	memset(snp_dev->certs_data, 0, req->certs_len);
591	npages = req->certs_len >> PAGE_SHIFT;
592cmd:
593	/*
594	 * The intermediate response buffer is used while decrypting the
595	 * response payload. Make sure that it has enough space to cover the
596	 * authtag.
597	 */
598	resp_len = sizeof(resp->data) + crypto->a_len;
599	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
600	if (!resp)
601		return -ENOMEM;
602
603	snp_dev->input.data_npages = npages;
604	ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg,
605				   SNP_MSG_REPORT_REQ, &req->data,
606				   sizeof(req->data), resp->data, resp_len);
607
608	/* If certs length is invalid then copy the returned length */
609	if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
610		req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
611
612		if (copy_to_user((void __user *)arg->req_data, req, sizeof(*req)))
613			ret = -EFAULT;
614	}
615
616	if (ret)
617		goto e_free;
618
619	if (npages &&
620	    copy_to_user((void __user *)req->certs_address, snp_dev->certs_data,
621			 req->certs_len)) {
622		ret = -EFAULT;
623		goto e_free;
624	}
625
626	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
627		ret = -EFAULT;
628
629e_free:
630	kfree(resp);
631	return ret;
632}
633
634static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long arg)
635{
636	struct snp_guest_dev *snp_dev = to_snp_dev(file);
637	void __user *argp = (void __user *)arg;
638	struct snp_guest_request_ioctl input;
639	int ret = -ENOTTY;
640
641	if (copy_from_user(&input, argp, sizeof(input)))
642		return -EFAULT;
643
644	input.exitinfo2 = 0xff;
645
646	/* Message version must be non-zero */
647	if (!input.msg_version)
648		return -EINVAL;
649
650	mutex_lock(&snp_cmd_mutex);
651
652	/* Check if the VMPCK is not empty */
653	if (is_vmpck_empty(snp_dev)) {
654		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
655		mutex_unlock(&snp_cmd_mutex);
656		return -ENOTTY;
657	}
658
659	switch (ioctl) {
660	case SNP_GET_REPORT:
661		ret = get_report(snp_dev, &input);
662		break;
663	case SNP_GET_DERIVED_KEY:
664		ret = get_derived_key(snp_dev, &input);
665		break;
666	case SNP_GET_EXT_REPORT:
667		ret = get_ext_report(snp_dev, &input);
668		break;
669	default:
670		break;
671	}
672
673	mutex_unlock(&snp_cmd_mutex);
674
675	if (input.exitinfo2 && copy_to_user(argp, &input, sizeof(input)))
676		return -EFAULT;
677
678	return ret;
679}
680
681static void free_shared_pages(void *buf, size_t sz)
682{
683	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
684	int ret;
685
686	if (!buf)
687		return;
688
689	ret = set_memory_encrypted((unsigned long)buf, npages);
690	if (ret) {
691		WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
692		return;
693	}
694
695	__free_pages(virt_to_page(buf), get_order(sz));
696}
697
698static void *alloc_shared_pages(struct device *dev, size_t sz)
699{
700	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
701	struct page *page;
702	int ret;
703
704	page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(sz));
705	if (!page)
706		return NULL;
707
708	ret = set_memory_decrypted((unsigned long)page_address(page), npages);
709	if (ret) {
710		dev_err(dev, "failed to mark page shared, ret=%d\n", ret);
711		__free_pages(page, get_order(sz));
712		return NULL;
713	}
714
715	return page_address(page);
716}
717
/* File operations of the misc device: snp_guest_ioctl() is the only handler provided. */
static const struct file_operations snp_guest_fops = {
	.owner	= THIS_MODULE,
	.unlocked_ioctl = snp_guest_ioctl,
};
722
723static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
724{
725	u8 *key = NULL;
726
727	switch (id) {
728	case 0:
729		*seqno = &layout->os_area.msg_seqno_0;
730		key = layout->vmpck0;
731		break;
732	case 1:
733		*seqno = &layout->os_area.msg_seqno_1;
734		key = layout->vmpck1;
735		break;
736	case 2:
737		*seqno = &layout->os_area.msg_seqno_2;
738		key = layout->vmpck2;
739		break;
740	case 3:
741		*seqno = &layout->os_area.msg_seqno_3;
742		key = layout->vmpck3;
743		break;
744	default:
745		break;
746	}
747
748	return key;
749}
750
751static int __init sev_guest_probe(struct platform_device *pdev)
752{
753	struct snp_secrets_page_layout *layout;
754	struct sev_guest_platform_data *data;
755	struct device *dev = &pdev->dev;
756	struct snp_guest_dev *snp_dev;
757	struct miscdevice *misc;
758	void __iomem *mapping;
759	int ret;
760
761	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
762		return -ENODEV;
763
764	if (!dev->platform_data)
765		return -ENODEV;
766
767	data = (struct sev_guest_platform_data *)dev->platform_data;
768	mapping = ioremap_encrypted(data->secrets_gpa, PAGE_SIZE);
769	if (!mapping)
770		return -ENODEV;
771
772	layout = (__force void *)mapping;
773
774	ret = -ENOMEM;
775	snp_dev = devm_kzalloc(&pdev->dev, sizeof(struct snp_guest_dev), GFP_KERNEL);
776	if (!snp_dev)
777		goto e_unmap;
778
779	ret = -EINVAL;
780	snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
781	if (!snp_dev->vmpck) {
782		dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
783		goto e_unmap;
784	}
785
786	/* Verify that VMPCK is not zero. */
787	if (is_vmpck_empty(snp_dev)) {
788		dev_err(dev, "vmpck id %d is null\n", vmpck_id);
789		goto e_unmap;
790	}
791
792	platform_set_drvdata(pdev, snp_dev);
793	snp_dev->dev = dev;
794	snp_dev->layout = layout;
795
796	/* Allocate the shared page used for the request and response message. */
797	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
798	if (!snp_dev->request)
799		goto e_unmap;
800
801	snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
802	if (!snp_dev->response)
803		goto e_free_request;
804
805	snp_dev->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE);
806	if (!snp_dev->certs_data)
807		goto e_free_response;
808
809	ret = -EIO;
810	snp_dev->crypto = init_crypto(snp_dev, snp_dev->vmpck, VMPCK_KEY_LEN);
811	if (!snp_dev->crypto)
812		goto e_free_cert_data;
813
814	misc = &snp_dev->misc;
815	misc->minor = MISC_DYNAMIC_MINOR;
816	misc->name = DEVICE_NAME;
817	misc->fops = &snp_guest_fops;
818
819	/* initial the input address for guest request */
820	snp_dev->input.req_gpa = __pa(snp_dev->request);
821	snp_dev->input.resp_gpa = __pa(snp_dev->response);
822	snp_dev->input.data_gpa = __pa(snp_dev->certs_data);
823
824	ret =  misc_register(misc);
825	if (ret)
826		goto e_free_cert_data;
827
828	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
829	return 0;
830
831e_free_cert_data:
832	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
833e_free_response:
834	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
835e_free_request:
836	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
837e_unmap:
838	iounmap(mapping);
839	return ret;
840}
841
842static int __exit sev_guest_remove(struct platform_device *pdev)
843{
844	struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
845
846	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
847	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
848	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
849	deinit_crypto(snp_dev->crypto);
850	misc_deregister(&snp_dev->misc);
851
852	return 0;
853}
854
855/*
856 * This driver is meant to be a common SEV guest interface driver and to
857 * support any SEV guest API. As such, even though it has been introduced
858 * with the SEV-SNP support, it is named "sev-guest".
859 */
/* No .probe member here: the probe routine is supplied at registration below. */
static struct platform_driver sev_guest_driver = {
	.remove		= __exit_p(sev_guest_remove),
	.driver		= {
		.name = "sev-guest",
	},
};

/* Register the driver with sev_guest_probe() as its probe routine. */
module_platform_driver_probe(sev_guest_driver, sev_guest_probe);

MODULE_AUTHOR("Brijesh Singh <brijesh.singh@amd.com>");
MODULE_LICENSE("GPL");
MODULE_VERSION("1.0.0");
MODULE_DESCRIPTION("AMD SEV Guest Driver");
MODULE_ALIAS("platform:sev-guest");
874