Revamp ROCm support

This refines where we extract the LLM libraries to by adding a new
OLLAMA_HOME env var, that defaults to `~/.ollama` The logic was already
idempotenent, so this should speed up startups after the first time a
new release is deployed.  It also cleans up after itself.

We now build only a single ROCm version (latest major) on both windows
and linux.  Given the large size of ROCms tensor files, we split the
dependency out.  It's bundled into the installer on windows, and a
separate download on windows.  The linux install script is now smart and
detects the presence of AMD GPUs and looks to see if rocm v6 is already
present, and if not, then downloads our dependency tar file.

For Linux discovery, we now use sysfs and check each GPU against what
ROCm supports so we can degrade to CPU gracefully instead of having
llama.cpp+rocm assert/crash on us.  For Windows, we now use go's windows
dynamic library loading logic to access the amdhip64.dll APIs to query
the GPU information.
This commit is contained in:
Daniel Hiltgen
2024-02-15 17:15:09 -08:00
parent 2e20110e50
commit 6c5ccb11f9
27 changed files with 1091 additions and 588 deletions

View File

@@ -1,101 +0,0 @@
package gpu
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
)
// TODO - windows vs. non-windows vs darwin
// Discovery logic for AMD/ROCm GPUs
const (
DriverVersionFile = "/sys/module/amdgpu/version"
GPUPropertiesFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/properties"
// TODO probably break these down per GPU to make the logic simpler
GPUTotalMemoryFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/mem_banks/*/properties" // size_in_bytes line
GPUUsedMemoryFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/mem_banks/*/used_memory"
)
func AMDDetected() bool {
// Some driver versions (older?) don't have a version file, so just lookup the parent dir
sysfsDir := filepath.Dir(DriverVersionFile)
_, err := os.Stat(sysfsDir)
if errors.Is(err, os.ErrNotExist) {
slog.Debug("amd driver not detected " + sysfsDir)
return false
} else if err != nil {
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
return false
}
return true
}
func AMDDriverVersion() (string, error) {
_, err := os.Stat(DriverVersionFile)
if err != nil {
return "", fmt.Errorf("amdgpu file stat error: %s %w", DriverVersionFile, err)
}
fp, err := os.Open(DriverVersionFile)
if err != nil {
return "", err
}
defer fp.Close()
verString, err := io.ReadAll(fp)
if err != nil {
return "", err
}
return strings.TrimSpace(string(verString)), nil
}
func AMDGFXVersions() []Version {
res := []Version{}
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
for _, match := range matches {
fp, err := os.Open(match)
if err != nil {
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
continue
}
defer fp.Close()
scanner := bufio.NewScanner(fp)
// optionally, resize scanner's capacity for lines over 64K, see next example
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
if len(ver) != 2 || len(ver[1]) < 5 {
slog.Debug("malformed " + line)
continue
}
l := len(ver[1])
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
res = append(res, Version{
Major: uint(major),
Minor: uint(minor),
Patch: uint(patch),
})
}
}
}
return res
}
func (v Version) ToGFXString() string {
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
}

58
gpu/amd_common.go Normal file
View File

@@ -0,0 +1,58 @@
//go:build linux || windows
package gpu
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
)
// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns
func rocmLibUsable(libDir string) bool {
slog.Debug("evaluating potential rocm lib dir " + libDir)
for _, g := range ROCmLibGlobs {
res, _ := filepath.Glob(filepath.Join(libDir, g))
if len(res) == 0 {
return false
}
}
return true
}
func GetSupportedGFX(libDir string) ([]string, error) {
var ret []string
files, err := filepath.Glob(filepath.Join(libDir, "rocblas", "library", "TensileLibrary_lazy_gfx*.dat"))
if err != nil {
return nil, err
}
for _, file := range files {
ret = append(ret, strings.TrimSuffix(strings.TrimPrefix(filepath.Base(file), "TensileLibrary_lazy_"), ".dat"))
}
return ret, nil
}
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
// Set the visible devices if not already set
// TODO - does sort order matter?
devices := []string{}
for i := range ids {
slog.Debug(fmt.Sprintf("i=%d", i))
if _, skipped := skip[i]; skipped {
slog.Debug("skipped")
continue
}
devices = append(devices, strconv.Itoa(i))
}
slog.Debug(fmt.Sprintf("devices=%v", devices))
val := strings.Join(devices, ",")
err := os.Setenv("HIP_VISIBLE_DEVICES", val)
if err != nil {
slog.Warn(fmt.Sprintf("failed to set env: %s", err))
}
slog.Debug("HIP_VISIBLE_DEVICES=" + val)
}

141
gpu/amd_hip_windows.go Normal file
View File

@@ -0,0 +1,141 @@
package gpu
import (
"fmt"
"log/slog"
"strconv"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
const (
hipSuccess = 0
hipErrorNoDevice = 100
)
type hipDevicePropMinimal struct {
Name [256]byte
unused1 [140]byte
GcnArchName [256]byte // gfx####
iGPU int // Doesn't seem to actually report correctly
unused2 [128]byte
}
// Wrap the amdhip64.dll library for GPU discovery
type HipLib struct {
dll windows.Handle
hipGetDeviceCount uintptr
hipGetDeviceProperties uintptr
hipMemGetInfo uintptr
hipSetDevice uintptr
hipDriverGetVersion uintptr
}
func NewHipLib() (*HipLib, error) {
h, err := windows.LoadLibrary("amdhip64.dll")
if err != nil {
return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
}
hl := &HipLib{}
hl.dll = h
hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount")
if err != nil {
return nil, err
}
hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties")
if err != nil {
return nil, err
}
hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo")
if err != nil {
return nil, err
}
hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice")
if err != nil {
return nil, err
}
hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion")
if err != nil {
return nil, err
}
return hl, nil
}
// The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup
// so we have to unload/reset the library after we do our initial discovery
// to make sure our updates to that variable are processed by llama.cpp
func (hl *HipLib) Release() {
err := windows.FreeLibrary(hl.dll)
if err != nil {
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
}
hl.dll = 0
}
func (hl *HipLib) AMDDriverVersion() (string, error) {
if hl.dll == 0 {
return "", fmt.Errorf("dll has been unloaded")
}
var version int
status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
if status != hipSuccess {
return "", fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
}
return strconv.Itoa(version), nil
}
func (hl *HipLib) HipGetDeviceCount() int {
if hl.dll == 0 {
slog.Error("dll has been unloaded")
return 0
}
var count int
status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
if status == hipErrorNoDevice {
slog.Info("AMD ROCm reports no devices found")
return 0
}
if status != hipSuccess {
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
}
return count
}
func (hl *HipLib) HipSetDevice(device int) error {
if hl.dll == 0 {
return fmt.Errorf("dll has been unloaded")
}
status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
if status != hipSuccess {
return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err)
}
return nil
}
func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
if hl.dll == 0 {
return nil, fmt.Errorf("dll has been unloaded")
}
var props hipDevicePropMinimal
status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
if status != hipSuccess {
return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err)
}
return &props, nil
}
// free, total, err
func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
if hl.dll == 0 {
return 0, 0, fmt.Errorf("dll has been unloaded")
}
var totalMemory uint64
var freeMemory uint64
status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory)))
if status != hipSuccess {
return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err)
}
return freeMemory, totalMemory, nil
}

411
gpu/amd_linux.go Normal file
View File

@@ -0,0 +1,411 @@
package gpu
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/jmorganca/ollama/version"
)
// Discovery logic for AMD/ROCm GPUs
const (
curlMsg = "curl -fsSL https://github.com/ollama/ollama/releases/download/v%s/rocm-amd64-deps.tgz | tar -zxf - -C %s"
DriverVersionFile = "/sys/module/amdgpu/version"
AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"
GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"
// Prefix with the node dir
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
RocmStandardLocation = "/opt/rocm/lib"
)
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
)
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
// and the user hasn't already set this variable
func AMDGetGPUInfo(resp *GpuInfo) {
// TODO - DRY this out with windows
if !AMDDetected() {
return
}
skip := map[int]interface{}{}
// Opportunistic logging of driver version to aid in troubleshooting
ver, err := AMDDriverVersion()
if err == nil {
slog.Info("AMD Driver: " + ver)
} else {
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
}
// If the user has specified exactly which GPUs to use, look up their memory
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
if visibleDevices != "" {
ids := []int{}
for _, idStr := range strings.Split(visibleDevices, ",") {
id, err := strconv.Atoi(idStr)
if err != nil {
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
} else {
ids = append(ids, id)
}
}
amdProcMemLookup(resp, nil, ids)
return
}
// Gather GFX version information from all detected cards
gfx := AMDGFXVersions()
verStrings := []string{}
for i, v := range gfx {
verStrings = append(verStrings, v.ToGFXString())
if v.Major == 0 {
// Silently skip CPUs
skip[i] = struct{}{}
continue
}
if v.Major < 9 {
// TODO consider this a build-time setting if we can support 8xx family GPUs
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
skip[i] = struct{}{}
}
}
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
// Abort if all GPUs are skipped
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
}
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
if gfxOverride == "" {
supported, err := GetSupportedGFX(libDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
return
}
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
for i, v := range gfx {
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
}
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
ids := make([]int, len(gfx))
i := 0
for k := range gfx {
ids[i] = k
i++
}
amdProcMemLookup(resp, skip, ids)
if resp.memInfo.DeviceCount == 0 {
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
}
// Walk the sysfs nodes for the available GPUs and gather information from them
// skipping over any devices in the skip map
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
if len(ids) == 0 {
slog.Debug("discovering all amdgpu devices")
entries, err := os.ReadDir(AMDNodesSysfsDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
return
}
for _, node := range entries {
if !node.IsDir() {
continue
}
id, err := strconv.Atoi(node.Name())
if err != nil {
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
continue
}
ids = append(ids, id)
}
}
slog.Debug(fmt.Sprintf("discovering amdgpu devices %v", ids))
for _, id := range ids {
if _, skipped := skip[id]; skipped {
continue
}
totalMemory := uint64(0)
usedMemory := uint64(0)
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUTotalMemoryFileGlob)
propFiles, err := filepath.Glob(propGlob)
if err != nil {
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
}
// 1 or more memory banks - sum the values of all of them
for _, propFile := range propFiles {
fp, err := os.Open(propFile)
if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
continue
}
defer fp.Close()
scanner := bufio.NewScanner(fp)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "size_in_bytes") {
ver := strings.Fields(line)
if len(ver) != 2 {
slog.Warn("malformed " + line)
continue
}
bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
if err != nil {
slog.Warn("malformed int " + line)
continue
}
totalMemory += bankSizeInBytes
}
}
}
if totalMemory == 0 {
continue
}
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
usedFiles, err := filepath.Glob(usedGlob)
if err != nil {
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
continue
}
for _, usedFile := range usedFiles {
fp, err := os.Open(usedFile)
if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
continue
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
continue
}
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
if err != nil {
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
continue
}
usedMemory += used
}
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %d", id, totalMemory))
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %d", id, (totalMemory - usedMemory)))
resp.memInfo.DeviceCount++
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += (totalMemory - usedMemory)
}
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm"
}
}
// Quick check for AMD driver so we can skip amdgpu discovery if not present
func AMDDetected() bool {
// Some driver versions (older?) don't have a version file, so just lookup the parent dir
sysfsDir := filepath.Dir(DriverVersionFile)
_, err := os.Stat(sysfsDir)
if errors.Is(err, os.ErrNotExist) {
slog.Debug("amdgpu driver not detected " + sysfsDir)
return false
} else if err != nil {
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
return false
}
return true
}
func setupLink(source, target string) error {
if err := os.RemoveAll(target); err != nil {
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
}
if err := os.Symlink(source, target); err != nil {
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
}
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
return nil
}
// Ensure the AMD rocm lib dir is wired up
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own
func AMDValidateLibDir() (string, error) {
// We rely on the rpath compiled into our library to find rocm
// so we establish a symlink to wherever we find it on the system
// to $AssetsDir/rocm
// If we already have a rocm dependency wired, nothing more to do
assetsDir, err := AssetsDir()
if err != nil {
return "", fmt.Errorf("unable to lookup lib dir: %w", err)
}
// Versioned directory
rocmTargetDir := filepath.Join(assetsDir, "rocm")
if rocmLibUsable(rocmTargetDir) {
return rocmTargetDir, nil
}
// Parent dir (unversioned)
commonRocmDir := filepath.Join(filepath.Dir(assetsDir), "rocm")
if rocmLibUsable(commonRocmDir) {
return rocmTargetDir, setupLink(commonRocmDir, rocmTargetDir)
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "lib")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
}
}
// Scan the library path for potential matches
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
for _, ldPath := range ldPaths {
d, err := filepath.Abs(ldPath)
if err != nil {
continue
}
if rocmLibUsable(d) {
return rocmTargetDir, setupLink(d, rocmTargetDir)
}
}
// Well known location(s)
if rocmLibUsable("/opt/rocm/lib") {
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
}
err = os.MkdirAll(rocmTargetDir, 0755)
if err != nil {
return "", fmt.Errorf("failed to create empty rocm dir %s %w", rocmTargetDir, err)
}
// If we still haven't found a usable rocm, the user will have to download it on their own
slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or run the following")
slog.Warn(fmt.Sprintf(curlMsg, version.Version, rocmTargetDir))
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}
func AMDDriverVersion() (string, error) {
_, err := os.Stat(DriverVersionFile)
if err != nil {
return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err)
}
fp, err := os.Open(DriverVersionFile)
if err != nil {
return "", err
}
defer fp.Close()
verString, err := io.ReadAll(fp)
if err != nil {
return "", err
}
return strings.TrimSpace(string(verString)), nil
}
func AMDGFXVersions() map[int]Version {
res := map[int]Version{}
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
for _, match := range matches {
fp, err := os.Open(match)
if err != nil {
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
continue
}
defer fp.Close()
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
continue
}
scanner := bufio.NewScanner(fp)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
if len(ver) != 2 || len(ver[1]) < 5 {
if ver[1] == "0" {
// Silently skip the CPU
continue
} else {
slog.Debug("malformed " + line)
}
res[i] = Version{
Major: 0,
Minor: 0,
Patch: 0,
}
continue
}
l := len(ver[1])
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
res[i] = Version{
Major: uint(major),
Minor: uint(minor),
Patch: uint(patch),
}
}
}
}
return res
}
func (v Version) ToGFXString() string {
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
}

190
gpu/amd_windows.go Normal file
View File

@@ -0,0 +1,190 @@
package gpu
import (
"bytes"
"fmt"
"log/slog"
"os"
"path/filepath"
"slices"
"strings"
)
const (
RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob?
// TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
iGPUName = "AMD Radeon(TM) Graphics"
)
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
)
func AMDGetGPUInfo(resp *GpuInfo) {
hl, err := NewHipLib()
if err != nil {
slog.Debug(err.Error())
return
}
defer hl.Release()
skip := map[int]interface{}{}
ids := []int{}
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
ver, err := hl.AMDDriverVersion()
if err == nil {
slog.Info("AMD Driver: " + ver)
} else {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
}
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
count := hl.HipGetDeviceCount()
if count == 0 {
return
}
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
}
var supported []string
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
if gfxOverride == "" {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
return
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
slog.Info(fmt.Sprintf("detected %d hip devices", count))
for i := 0; i < count; i++ {
ids = append(ids, i)
err = hl.HipSetDevice(i)
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
skip[i] = struct{}{}
continue
}
props, err := hl.HipGetDeviceProperties(i)
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
skip[i] = struct{}{}
continue
}
n := bytes.IndexByte(props.Name[:], 0)
name := string(props.Name[:n])
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
n = bytes.IndexByte(props.GcnArchName[:], 0)
gfx := string(props.GcnArchName[:n])
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
// TODO Why isn't props.iGPU accurate!?
if strings.EqualFold(name, iGPUName) {
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
skip[i] = struct{}{}
continue
}
if gfxOverride == "" {
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
continue
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
}
}
totalMemory, freeMemory, err := hl.HipMemGetInfo()
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
continue
}
// TODO according to docs, freeMem may lie on windows!
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory))
resp.memInfo.DeviceCount++
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += freeMemory
}
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm"
}
// Abort if all GPUs are skipped
if len(skip) >= count {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
UpdatePath(libDir)
}
func AMDValidateLibDir() (string, error) {
// On windows non-admins typically can't create links
// so instead of trying to rely on rpath and a link in
// $LibDir/rocm, we instead rely on setting PATH to point
// to the location of the ROCm library
// Installer payload location
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// If we already have a rocm dependency wired, nothing more to do
libDir, err := AssetsDir()
if err != nil {
return "", fmt.Errorf("unable to lookup lib dir: %w", err)
}
rocmTargetDir := filepath.Join(libDir, "rocm")
if rocmLibUsable(rocmTargetDir) {
return rocmTargetDir, nil
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Installer payload (if we're running from some other location)
localAppData := os.Getenv("LOCALAPPDATA")
appDir := filepath.Join(localAppData, "Programs", "Ollama")
rocmTargetDir = filepath.Join(appDir, "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ollama installed ROCm at " + rocmTargetDir)
return rocmTargetDir, nil
}
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm v6")
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

60
gpu/assets.go Normal file
View File

@@ -0,0 +1,60 @@
package gpu
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/jmorganca/ollama/version"
)
func AssetsDir() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
baseDir := filepath.Join(home, ".ollama", "assets")
libDirs, err := os.ReadDir(baseDir)
if err == nil {
for _, d := range libDirs {
if d.Name() == version.Version {
continue
}
// Special case the rocm dependencies, which are handled by the installer
if d.Name() == "rocm" {
continue
}
slog.Debug("stale lib detected, cleaning up " + d.Name())
err = os.RemoveAll(filepath.Join(baseDir, d.Name()))
if err != nil {
slog.Warn(fmt.Sprintf("unable to clean up stale library %s: %s", filepath.Join(baseDir, d.Name()), err))
}
}
}
return filepath.Join(baseDir, version.Version), nil
}
func UpdatePath(dir string) {
if runtime.GOOS == "windows" {
tmpDir := filepath.Dir(dir)
pathComponents := strings.Split(os.Getenv("PATH"), ";")
i := 0
for _, comp := range pathComponents {
if strings.EqualFold(comp, dir) {
return
}
// Remove any other prior paths to our temp dir
if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
pathComponents[i] = comp
i++
}
}
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
os.Setenv("PATH", newPath)
}
// linux and darwin rely on rpath
}

View File

@@ -24,7 +24,6 @@ import (
type handles struct {
cuda *C.cuda_handle_t
rocm *C.rocm_handle_t
}
var gpuMutex sync.Mutex
@@ -54,39 +53,23 @@ var CudaWindowsGlobs = []string{
"c:\\Windows\\System32\\nvml.dll",
}
var RocmLinuxGlobs = []string{
"/opt/rocm*/lib*/librocm_smi64.so*",
}
var RocmWindowsGlobs = []string{
"c:\\Windows\\System32\\rocm_smi64.dll",
}
// Note: gpuMutex must already be held
func initGPUHandles() {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
gpuHandles = &handles{nil, nil}
gpuHandles = &handles{nil}
var cudaMgmtName string
var cudaMgmtPatterns []string
var rocmMgmtName string
var rocmMgmtPatterns []string
switch runtime.GOOS {
case "windows":
cudaMgmtName = "nvml.dll"
cudaMgmtPatterns = make([]string, len(CudaWindowsGlobs))
copy(cudaMgmtPatterns, CudaWindowsGlobs)
rocmMgmtName = "rocm_smi64.dll"
rocmMgmtPatterns = make([]string, len(RocmWindowsGlobs))
copy(rocmMgmtPatterns, RocmWindowsGlobs)
case "linux":
cudaMgmtName = "libnvidia-ml.so"
cudaMgmtPatterns = make([]string, len(CudaLinuxGlobs))
copy(cudaMgmtPatterns, CudaLinuxGlobs)
rocmMgmtName = "librocm_smi64.so"
rocmMgmtPatterns = make([]string, len(RocmLinuxGlobs))
copy(rocmMgmtPatterns, RocmLinuxGlobs)
default:
return
}
@@ -101,16 +84,6 @@ func initGPUHandles() {
return
}
}
rocmLibPaths := FindGPULibs(rocmMgmtName, rocmMgmtPatterns)
if len(rocmLibPaths) > 0 {
rocm := LoadROCMMgmt(rocmLibPaths)
if rocm != nil {
slog.Info("Radeon GPU detected")
gpuHandles.rocm = rocm
return
}
}
}
func GetGPUInfo() GpuInfo {
@@ -149,66 +122,10 @@ func GetGPUInfo() GpuInfo {
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
}
} else if AMDDetected() && gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
ver, err := AMDDriverVersion()
if err == nil {
slog.Info("AMD Driver: " + ver)
} else {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug("error looking up amd driver version: %s", err)
}
gfx := AMDGFXVersions()
tooOld := false
for _, v := range gfx {
if v.Major < 9 {
slog.Info("AMD GPU too old, falling back to CPU " + v.ToGFXString())
tooOld = true
break
}
// TODO - remap gfx strings for unsupporetd minor/patch versions to supported for the same major
// e.g. gfx1034 works if we map it to gfx1030 at runtime
}
if !tooOld {
// TODO - this algo can be shifted over to use sysfs instead of the rocm info library...
C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
// Only one GPU detected and it appears to be an integrated GPU - skip it
slog.Info("ROCm unsupported integrated GPU detected")
} else if memInfo.count > 0 {
if memInfo.igpu_index >= 0 {
// We have multiple GPUs reported, and one of them is an integrated GPU
// so we have to set the env var to bypass it
// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
val := os.Getenv("ROCR_VISIBLE_DEVICES")
if val == "" {
devices := []string{}
for i := 0; i < int(memInfo.count); i++ {
if i == int(memInfo.igpu_index) {
continue
}
devices = append(devices, strconv.Itoa(i))
}
val = strings.Join(devices, ",")
os.Setenv("ROCR_VISIBLE_DEVICES", val)
}
slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
}
resp.Library = "rocm"
var version C.rocm_version_resp_t
C.rocm_get_version(*gpuHandles.rocm, &version)
verString := C.GoString(version.str)
if version.status == 0 {
resp.Variant = "v" + verString
} else {
slog.Info(fmt.Sprintf("failed to look up ROCm version: %s", verString))
}
C.free(unsafe.Pointer(version.str))
}
} else {
AMDGetGPUInfo(&resp)
if resp.Library != "" {
return resp
}
}
if resp.Library == "" {
@@ -338,23 +255,6 @@ func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
return nil
}
func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
var resp C.rocm_init_resp_t
resp.rh.verbose = getVerboseState()
for _, libPath := range rocmLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.rocm_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load ROCm management library %s: %s", libPath, C.GoString(resp.err)))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.rh
}
}
return nil
}
func getVerboseState() C.uint16_t {
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
return C.uint16_t(1)

View File

@@ -53,7 +53,6 @@ void cpu_check_ram(mem_info_t *resp);
#endif
#include "gpu_info_cuda.h"
#include "gpu_info_rocm.h"
#endif // __GPU_INFO_H__
#endif // __APPLE__

View File

@@ -124,31 +124,31 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
if (ret != RSMI_STATUS_SUCCESS) {
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);

View File

@@ -1,198 +0,0 @@
#ifndef __APPLE__
#include "gpu_info_rocm.h"
#include <string.h>
void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
rsmi_status_t ret;
resp->err = NULL;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup {
char *s;
void **p;
} l[] = {
{"rsmi_init", (void *)&resp->rh.rsmi_init},
{"rsmi_shut_down", (void *)&resp->rh.rsmi_shut_down},
{"rsmi_dev_memory_total_get", (void *)&resp->rh.rsmi_dev_memory_total_get},
{"rsmi_dev_memory_usage_get", (void *)&resp->rh.rsmi_dev_memory_usage_get},
{"rsmi_version_get", (void *)&resp->rh.rsmi_version_get},
{"rsmi_num_monitor_devices", (void*)&resp->rh.rsmi_num_monitor_devices},
{"rsmi_dev_id_get", (void*)&resp->rh.rsmi_dev_id_get},
{"rsmi_dev_name_get", (void *)&resp->rh.rsmi_dev_name_get},
{"rsmi_dev_brand_get", (void *)&resp->rh.rsmi_dev_brand_get},
{"rsmi_dev_vendor_name_get", (void *)&resp->rh.rsmi_dev_vendor_name_get},
{"rsmi_dev_vram_vendor_get", (void *)&resp->rh.rsmi_dev_vram_vendor_get},
{"rsmi_dev_serial_number_get", (void *)&resp->rh.rsmi_dev_serial_number_get},
{"rsmi_dev_subsystem_name_get", (void *)&resp->rh.rsmi_dev_subsystem_name_get},
{"rsmi_dev_vbios_version_get", (void *)&resp->rh.rsmi_dev_vbios_version_get},
{NULL, NULL},
};
resp->rh.handle = LOAD_LIBRARY(rocm_lib_path, RTLD_LAZY);
if (!resp->rh.handle) {
char *msg = LOAD_ERR();
snprintf(buf, buflen,
"Unable to load %s library to query for Radeon GPUs: %s\n",
rocm_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->rh.verbose, "wiring rocm management library functions in %s\n", rocm_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->rh.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s);
if (!l[i].p) {
resp->rh.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->rh.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->rh.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->rh.rsmi_init)(0);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(resp->rh.verbose, "rsmi_init err: %d\n", ret);
UNLOAD_LIBRARY(resp->rh.handle);
resp->rh.handle = NULL;
snprintf(buf, buflen, "rocm vram init failure: %d", ret);
resp->err = strdup(buf);
}
return;
}
void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
resp->err = NULL;
resp->igpu_index = -1;
uint64_t totalMem = 0;
uint64_t usedMem = 0;
rsmi_status_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("rocm handle not initialized");
return;
}
ret = (*h.rsmi_num_monitor_devices)(&resp->count);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
LOG(h.verbose, "discovered %d ROCm GPU Devices\n", resp->count);
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp->count; i++) {
if (h.verbose) {
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.rsmi_dev_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm device name: %s\n", i, buf);
}
ret = (*h.rsmi_dev_brand_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_brand_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm brand: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vendor_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vendor_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm vendor: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vram_vendor_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vram_vendor_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm VRAM vendor: %s\n", i, buf);
}
ret = (*h.rsmi_dev_serial_number_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_serial_number_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm S/N: %s\n", i, buf);
}
ret = (*h.rsmi_dev_subsystem_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_subsystem_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm subsystem name: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vbios_version_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vbios_version_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm vbios version: %s\n", i, buf);
}
}
// Get total memory - used memory for available memory
ret = (*h.rsmi_dev_memory_total_get)(i, RSMI_MEM_TYPE_VRAM, &totalMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret);
resp->err = strdup(buf);
return;
}
ret = (*h.rsmi_dev_memory_usage_get)(i, RSMI_MEM_TYPE_VRAM, &usedMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
resp->err = strdup(buf);
return;
}
LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
if (totalMem < 1024 * 1024 * 1024) {
// Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory
LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
resp->igpu_index = i;
} else {
resp->total += totalMem;
resp->free += totalMem - usedMem;
}
}
}
void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
const int buflen = 256;
char buf[buflen + 1];
if (h.handle == NULL) {
resp->str = strdup("rocm handle not initialized");
resp->status = 1;
return;
}
rsmi_version_t ver;
rsmi_status_t ret;
ret = h.rsmi_version_get(&ver);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "unexpected response on version lookup %d", ret);
resp->status = 1;
} else {
snprintf(buf, buflen, "%d", ver.major);
resp->status = 0;
}
resp->str = strdup(buf);
}
#endif // __APPLE__

View File

@@ -1,59 +0,0 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_ROCM_H__
#define __GPU_INFO_ROCM_H__
#include "gpu_info.h"
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum rsmi_status_return {
RSMI_STATUS_SUCCESS = 0,
// Other values omitted for now...
} rsmi_status_t;
typedef enum rsmi_memory_type {
RSMI_MEM_TYPE_VRAM = 0,
RSMI_MEM_TYPE_VIS_VRAM,
RSMI_MEM_TYPE_GTT,
} rsmi_memory_type_t;
typedef struct {
uint32_t major;
uint32_t minor;
uint32_t patch;
const char *build;
} rsmi_version_t;
typedef struct rocm_handle {
void *handle;
uint16_t verbose;
rsmi_status_t (*rsmi_init)(uint64_t);
rsmi_status_t (*rsmi_shut_down)(void);
rsmi_status_t (*rsmi_dev_memory_total_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
rsmi_status_t (*rsmi_dev_memory_usage_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
rsmi_status_t (*rsmi_version_get) (rsmi_version_t *version);
rsmi_status_t (*rsmi_num_monitor_devices) (uint32_t *);
rsmi_status_t (*rsmi_dev_id_get)(uint32_t, uint16_t *);
rsmi_status_t (*rsmi_dev_name_get) (uint32_t,char *,size_t);
rsmi_status_t (*rsmi_dev_brand_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vendor_name_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vram_vendor_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_serial_number_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_subsystem_name_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vbios_version_get) (uint32_t, char *, uint32_t);
} rocm_handle_t;
typedef struct rocm_init_resp {
char *err; // If err is non-null handle is invalid
rocm_handle_t rh;
} rocm_init_resp_t;
typedef struct rocm_version_resp {
rsmi_status_t status;
char *str; // Contains version or error string if status != 0
} rocm_version_resp_t;
void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp);
void rocm_check_vram(rocm_handle_t rh, mem_info_t *resp);
void rocm_get_version(rocm_handle_t rh, rocm_version_resp_t *resp);
#endif // __GPU_INFO_ROCM_H__
#endif // __APPLE__