Calling COM code from Go

How to call COM code from Go

In my previous post I talked about how to call C functions directly using the syscall package, without using Cgo. We can take this idea a bit further and call COM objects from Go. As a simple example, let's see if we can call the IMalloc interface implemented in Windows.

To give a bit of background, IMalloc is a COM interface that provides malloc/free-equivalent functionality. More importantly, the process-wide allocator obtained through CoGetMalloc lets a caller and a callee exchange memory ownership, as long as they agree on using the same allocator. Simply put, module A can allocate a buffer and pass it to module B, which can then free it, because both use the process-wide IMalloc allocator. C++ developers are probably familiar with the scenario: if you malloc an object in one module and pass it to another module that frees it, there is no guarantee that the free will succeed, since the two modules might use different allocators (if they link against different CRT libraries, or each links the CRT statically).

The interface is defined as follows:


interface IMalloc : public IUnknown
{
public:
    virtual void *STDMETHODCALLTYPE Alloc( 
        /* [in] */ SIZE_T cb) = 0;
    
    virtual void *STDMETHODCALLTYPE Realloc( 
        /* [in] */ void *pv,
        /* [in] */ SIZE_T cb) = 0;
    
    virtual void STDMETHODCALLTYPE Free( 
        /* [in] */ void *pv) = 0;
    
    virtual SIZE_T STDMETHODCALLTYPE GetSize( 
        /* [in] */ void *pv) = 0;
    
    virtual int STDMETHODCALLTYPE DidAlloc( 
        void *pv) = 0;
    
    virtual void STDMETHODCALLTYPE HeapMinimize( void) = 0;
    
};

The functions here are pretty straightforward. If you are curious, you can refer to the documentation.

Calling CoGetMalloc is straightforward using the syscall package, as discussed in the previous post. Calling the underlying COM object requires slightly more work. COM is an ABI protocol that supports simple v-table based call dispatch, lifetime management, and threading (probably the most confusing part). To call a COM interface, you only need to call through the v-table, which is an array of function pointers pointing to the underlying code. In Go, a v-table definition would look something like this:

type MallocVtbl struct {
    queryInterface uintptr 
    addref uintptr
    release uintptr
    alloc uintptr
    realloc uintptr
    free uintptr
    getSize uintptr
    didAlloc uintptr
    heapMinimize uintptr
} 

Note that the order is very important: it must match the interface exactly, and the first three entries always come from the IUnknown methods (QueryInterface, AddRef, Release).
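To make the ordering concrete, here is a small sketch that derives each slot number from the struct layout using unsafe.Offsetof. The helper names slotIndex and vtblSlots are made up for illustration; the struct is the same one shown above.

```go
package main

import (
	"fmt"
	"unsafe"
)

// MallocVtbl mirrors the IMalloc v-table: the three IUnknown slots first,
// then the IMalloc methods in declaration order.
type MallocVtbl struct {
	queryInterface uintptr
	addref         uintptr
	release        uintptr
	alloc          uintptr
	realloc        uintptr
	free           uintptr
	getSize        uintptr
	didAlloc       uintptr
	heapMinimize   uintptr
}

// slotIndex (hypothetical helper) converts a field's byte offset
// into a v-table slot number.
func slotIndex(offset uintptr) int {
	return int(offset / unsafe.Sizeof(uintptr(0)))
}

// vtblSlots (hypothetical helper) reports the slots of a few methods.
func vtblSlots() (release, alloc, heapMinimize int) {
	var v MallocVtbl
	return slotIndex(unsafe.Offsetof(v.release)),
		slotIndex(unsafe.Offsetof(v.alloc)),
		slotIndex(unsafe.Offsetof(v.heapMinimize))
}

func main() {
	release, alloc, heapMinimize := vtblSlots()
	fmt.Println("release:", release, "alloc:", alloc, "heapMinimize:", heapMinimize)
}
```

If a field were inserted or reordered by mistake, every slot after it would shift, and a call through that field would end up in the wrong method.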

The v-table pointer can be obtained from the first pointer in the object.

var malloc uintptr // This is the IMalloc*

mallocVtblPtr := *(*uintptr)(unsafe.Pointer(malloc)) // Get the vtable pointer

mallocVtbl := (*MallocVtbl)(unsafe.Pointer(mallocVtblPtr))  // Convert to the right MallocVtbl*
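The same pointer chasing can be tried without any COM object at all. The sketch below builds a stand-in object in ordinary Go memory whose first field is a v-table pointer, then recovers the table from it. The types fakeVtbl and fakeObject and the slot values are invented for illustration; they are markers, not callable addresses.

```go
package main

import (
	"fmt"
	"unsafe"
)

// fakeVtbl stands in for a real v-table. The values stored in it are
// arbitrary markers, not real function addresses.
type fakeVtbl struct {
	queryInterface uintptr
	addref         uintptr
	release        uintptr
	alloc          uintptr
}

// fakeObject mimics the layout of a COM object: the v-table pointer
// is the first pointer-sized field.
type fakeObject struct {
	vtbl     *fakeVtbl
	refCount int32
}

// vtblOf reads the first pointer-sized word of the object and treats
// it as the v-table pointer.
func vtblOf(obj unsafe.Pointer) *fakeVtbl {
	return *(**fakeVtbl)(obj)
}

// recoverVtbl builds a fake object and checks that vtblOf finds its table.
func recoverVtbl() (sameTable bool, allocSlot uintptr) {
	table := &fakeVtbl{queryInterface: 0x10, addref: 0x20, release: 0x30, alloc: 0x40}
	obj := &fakeObject{vtbl: table, refCount: 1}
	got := vtblOf(unsafe.Pointer(obj))
	return got == table, got.alloc
}

func main() {
	same, alloc := recoverVtbl()
	fmt.Printf("same table: %v, alloc slot: %#x\n", same, alloc)
}
```

With a real IMalloc* the steps are identical; the only difference is that the object and its table live in memory owned by ole32.dll rather than by Go.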

Once we have the v-table, getting the underlying function pointer is a simple matter of reading the right field, and then we can call it using syscall like before:

memPtr, _, err := syscall.Syscall(mallocVtbl.alloc, uintptr(2), malloc, uintptr(memSize), uintptr(0))  
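Note the shape of that call: the interface pointer itself goes first (it is the implicit this argument of every COM method), the count covers this plus the real arguments, and unused positions are padded with zero. As a rough sketch of that convention, the hypothetical helper comArgs below builds the argument list that would be handed to syscall.Syscall. It does not perform the call, and the pointer value in main is a placeholder.

```go
package main

import "fmt"

// comArgs (hypothetical helper) prepares arguments for a three-argument
// syscall.Syscall when invoking a COM method: the this pointer first,
// then the method's own arguments, zero-padded.
func comArgs(this uintptr, args ...uintptr) (nargs uintptr, a [3]uintptr, err error) {
	if len(args) > 2 {
		// this + more than two arguments does not fit in three slots.
		return 0, a, fmt.Errorf("%d arguments do not fit; a larger Syscall variant is needed", len(args))
	}
	a[0] = this
	copy(a[1:], args)
	return uintptr(len(args) + 1), a, nil
}

func main() {
	// Alloc(this, cb = 100); 0x1000 is a placeholder, not a real IMalloc*.
	nargs, a, err := comArgs(0x1000, 100)
	fmt.Println(nargs, a, err)
}
```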

You can find the full code below:


package main

import (
    "fmt"
    "log"
    "syscall"
    "unsafe"
)

type MallocVtbl struct {
    queryInterface uintptr 
    addref uintptr
    release uintptr
    alloc uintptr
    realloc uintptr
    free uintptr
    getSize uintptr
    didAlloc uintptr
    heapMinimize uintptr
}   

func main() {
    fmt.Printf("Loading ole32.dll...\n")

    handle, err := syscall.LoadLibrary("ole32.dll")
    if err != nil {
        log.Fatal(err)
    }

    proc, err := syscall.GetProcAddress(handle, "CoGetMalloc")
    if err != nil {
        log.Fatal(err)
    }

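    fmt.Printf("Calling CoGetMalloc\n")
    var malloc uintptr
    // The first argument (dwMemContext) must be 1. The HRESULT comes back in ret.
    ret, _, _ := syscall.Syscall(proc, uintptr(2), uintptr(1), uintptr(unsafe.Pointer(&malloc)), uintptr(0))
    // ret is unsigned, so reinterpret it as a signed 32-bit HRESULT before testing for failure.
    if int32(ret) < 0 {
        log.Fatalf("CoGetMalloc failed with HRESULT 0x%08x", uint32(ret))
    }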

    fmt.Printf("CoGetMalloc returned %x\n\n", malloc) 

    mallocVtblPtr := *(*uintptr)(unsafe.Pointer(malloc))
    mallocVtbl := (*MallocVtbl)(unsafe.Pointer(mallocVtblPtr))

    const memSize int = 100
    fmt.Printf("Calling IMalloc::Alloc(%v)\n", memSize)
    memPtr, _, _ := syscall.Syscall(mallocVtbl.alloc, uintptr(2), malloc, uintptr(memSize), uintptr(0))

    // Alloc signals failure by returning NULL.
    if memPtr == 0 {
        log.Fatal("IMalloc::Alloc returned NULL")
    }
       
    fmt.Printf("IMalloc::Alloc returned %x\n\n", memPtr) 
 
    fmt.Printf("Calling IMalloc::GetSize with %x\n", memPtr)

    returnedSize, _, err := syscall.Syscall(mallocVtbl.getSize, uintptr(2), malloc, memPtr, uintptr(0))
    fmt.Printf("IMalloc::GetSize returned %v\n\n", returnedSize)

    fmt.Printf("Calling IMalloc::Free with %x\n", memPtr)
    syscall.Syscall(mallocVtbl.free, uintptr(2), malloc, memPtr, uintptr(0))
    fmt.Printf("IMalloc::Free succeeded\n\n")

    fmt.Printf("Calling IMalloc::DidAlloc with %x\n", memPtr)
    ret, _, err = syscall.Syscall(mallocVtbl.didAlloc, uintptr(2), malloc, memPtr, uintptr(0))
    // DidAlloc returns 1 (allocated here), 0 (not allocated here), or -1 (cannot tell).
    didAlloc := int32(ret)
    fmt.Printf("IMalloc::DidAlloc returned %v\n\n", didAlloc)

    fmt.Printf("Calling IMalloc::Release\n")
    ret, _, err = syscall.Syscall(mallocVtbl.release, uintptr(1), malloc, uintptr(0), uintptr(0))
    fmt.Printf("IMalloc::Release returned %v\n\n", ret)
}
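
A note on error checking: CoGetMalloc returns an HRESULT, which syscall.Syscall hands back as an unsigned uintptr. A failure code has the high bit of its 32-bit value set, so the value has to be reinterpreted as a signed 32-bit integer before it is tested. Here is a minimal sketch; the helper name failed is made up for illustration and mirrors the FAILED macro from the Windows headers.

```go
package main

import "fmt"

// failed (hypothetical helper) reports whether an HRESULT, as returned
// in a uintptr by syscall.Syscall, indicates failure. Truncating to
// int32 also discards any upper bits on 64-bit platforms.
func failed(hr uintptr) bool {
	return int32(hr) < 0
}

func main() {
	const (
		sOK   = 0x00000000 // S_OK
		eFail = 0x80004005 // E_FAIL
	)
	fmt.Println("S_OK failed:", failed(sOK))
	fmt.Println("E_FAIL failed:", failed(eFail))
}
```

A plain `ret < 0` comparison on a uintptr can never be true, which is why the conversion matters.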