TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
cuda_safe_headers.h
#ifndef CUDA_SAFE_HEADERS_H
#define CUDA_SAFE_HEADERS_H

#ifdef HAS_CUDA

// For Windows CUDA 12.1+, we need to be very careful about header inclusion
#if defined(WINDOWS_CUDA_12_1_WORKAROUND) && defined(_WIN32)

// Include essential CUDA runtime first
#include <cuda_runtime.h>

// Define the half precision types that cuBLAS needs BEFORE including cuBLAS
// These definitions must be compatible with what cuBLAS expects
#ifndef __CUDA_FP16_TYPES_EXIST__
#define __CUDA_FP16_TYPES_EXIST__

// Basic half precision type - must match NVIDIA's definition exactly
typedef struct __align__(2) {
    unsigned short __x;
} __half;

// Half2 type for paired operations
typedef struct __align__(4) {
    __half x, y;
} __half2;

// Essential half precision constants and basic operations for cuBLAS compatibility
#ifdef __cplusplus
extern "C" {
#endif

// Declare and implement essential functions that cuBLAS might expect
// These are minimal stub implementations to satisfy linking requirements
static inline __host__ __device__ __half __float2half(const float a) {
    __half result;
    // Stub conversion: keeps the upper 16 bits of the float (sign, 8-bit exponent,
    // 7 mantissa bits). This is a bfloat16-style truncation, not an IEEE-754
    // binary16 encoding; it exists only to satisfy compilation and linking.
    result.__x = (unsigned short)((*(unsigned int*)&a) >> 16);
    return result;
}

static inline __host__ __device__ float __half2float(const __half a) {
    // Stub conversion: zero-extends the stored 16 bits back into the upper half
    // of a float. Matches the truncation above, but does not decode true IEEE fp16.
    unsigned int temp = ((unsigned int)a.__x) << 16;
    return *(float*)&temp;
}

#ifdef __cplusplus
}
#endif

#endif // __CUDA_FP16_TYPES_EXIST__
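
// Worked example of the stub behavior above (bit patterns are ordinary IEEE-754
// single-precision values, shown for illustration only): 1.0f is 0x3F800000, so
// __float2half stores 0x3F80 and __half2float reproduces 0x3F800000 == 1.0f.
// The stored pattern is a bfloat16 pattern, not a true binary16 (fp16) value,
// so it must not be handed to code that reinterprets __half as real fp16 data.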

#ifndef __CUDA_BF16_TYPES_EXIST__
#define __CUDA_BF16_TYPES_EXIST__

// BFloat16 type definition
typedef struct __align__(2) {
    unsigned short __x;
} __nv_bfloat16;

// BFloat16 pair type
typedef struct __align__(4) {
    __nv_bfloat16 x, y;
} __nv_bfloat162;

#ifdef __cplusplus
extern "C" {
#endif

// Essential BF16 conversion functions for cuBLAS compatibility
static inline __host__ __device__ __nv_bfloat16 __float2bfloat16(const float a) {
    __nv_bfloat16 result;
    // Basic float to bfloat16 conversion (truncate mantissa, keep exponent and sign)
    result.__x = (unsigned short)((*(unsigned int*)&a) >> 16);
    return result;
}

static inline __host__ __device__ float __bfloat162float(const __nv_bfloat16 a) {
    // Basic bfloat16 to float conversion (zero-extend mantissa)
    unsigned int temp = ((unsigned int)a.__x) << 16;
    return *(float*)&temp;
}

#ifdef __cplusplus
}
#endif

#endif // __CUDA_BF16_TYPES_EXIST__
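
// Note: unlike the __half stubs above, for finite inputs this truncation is a
// valid round-toward-zero float-to-bfloat16 conversion, because __nv_bfloat16
// shares float's sign bit and 8-bit exponent; only the low 16 mantissa bits
// are discarded.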

// Now safely include cuBLAS - the types it needs are defined above
#ifndef CUBLAS_V2_H_
#include <cublas_v2.h>
#endif

#else
// For non-problematic platforms, include headers normally
#include <cuda_runtime.h>
#include <cublas_v2.h>

// Only include these if they have not already been included
// (or deliberately blocked via their include-guard macros)
#ifndef __CUDA_FP16_H__
#include <cuda_fp16.h>
#endif

#ifndef __CUDA_BF16_H__
#include <cuda_bf16.h>
#endif

#endif // WINDOWS_CUDA_12_1_WORKAROUND

#endif // HAS_CUDA

#endif // CUDA_SAFE_HEADERS_H
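
A minimal usage sketch, assuming a hypothetical consumer translation unit (example.cpp, not part of this header): the wrapper is included in place of the individual CUDA/cuBLAS headers, and the cuBLAS API is then available on both the Windows workaround path and the normal include path. The file name and build flags below are illustrative; HAS_CUDA and WINDOWS_CUDA_12_1_WORKAROUND are expected to come from the project's build configuration.

// example.cpp - hypothetical consumer of cuda_safe_headers.h (illustrative only)
// Build sketch (assumed flags): nvcc -DHAS_CUDA example.cpp -lcublas
#include "cuda_safe_headers.h"
#include <cstdio>

int main() {
#ifdef HAS_CUDA
    cublasHandle_t handle = nullptr;
    // cuda_safe_headers.h has already pulled in cuda_runtime.h and cublas_v2.h
    // in a conflict-free order, so cuBLAS can be used directly here.
    if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
        std::fprintf(stderr, "cublasCreate failed\n");
        return 1;
    }
    int version = 0;
    cublasGetVersion(handle, &version);
    std::printf("cuBLAS version: %d\n", version);
    cublasDestroy(handle);
#endif
    return 0;
}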