/*   $Source: bitbucket.org:berkeleylab/gasnet.git/other/kinds/gasnet_refkinds.c $
 * Description: GASNet Memory Kinds Implementation
 * Copyright (c) 2020, The Regents of the University of California
 * Terms of use are as specified in license.txt
 */

#define GASNETI_NEED_GASNET_MK_H 1
#include <gasnet_internal.h>
#include <gasnet_kinds_internal.h>

// Convenience macro: access a member of a memory kind's implementation
// descriptor.  MK_IMPL(i_mk,foo) expands to ((i_mk)->_mk_impl->mk_foo), so
// `short_field` is the member name written without its "mk_" prefix.
#define MK_IMPL(i_mk,short_field) ((i_mk)->_mk_impl->mk_##short_field)

#ifndef gasneti_import_mk
// Convert a client-facing gex_MK_t handle into the internal gasneti_MK_t.
// The handle is converted with GASNETI_IMPORT_POINTER, handed to
// GASNETI_IMPORT_MAGIC with the MK tag, and then returned.
// NOTE(review): presumably GASNETI_IMPORT_MAGIC validates the object's magic
// value -- confirm against the macro's definition.
// This out-of-line definition is only compiled when no macro named
// gasneti_import_mk is already defined (see the enclosing #ifndef).
gasneti_MK_t gasneti_import_mk(gex_MK_t _mk) {
  const gasneti_MK_t _real_mk = GASNETI_IMPORT_POINTER(gasneti_MK_t,_mk);
  GASNETI_IMPORT_MAGIC(_real_mk, MK);
  return _real_mk;
}
#endif

#ifndef gasneti_import_mk_nonhost
// Import a handle in a context where GEX_MK_HOST is not permitted.
// Passing GEX_MK_HOST is reported through gasneti_fatalerror(); any other
// handle is forwarded to gasneti_import_mk() and its result returned.
gasneti_MK_t gasneti_import_mk_nonhost(gex_MK_t _mk) {
  const int is_host = (_mk == GEX_MK_HOST);
  if (is_host) {
    gasneti_fatalerror("Invalid use of GEX_MK_HOST where prohibited");
  }
  gasneti_MK_t imported = gasneti_import_mk(_mk);
  return imported;
}
// Variant of gasneti_import_mk_nonhost() that additionally asserts the
// handle is not GEX_MK_INVALID before importing it.
gasneti_MK_t gasneti_import_mk_nonhost_valid(gex_MK_t mk) {
  gasneti_assert(mk != GEX_MK_INVALID);
  gasneti_MK_t imported = gasneti_import_mk_nonhost(mk);
  return imported;
}
#endif

#ifndef gasneti_export_mk
// Convert an internal gasneti_MK_t into the client-facing gex_MK_t handle.
// The object is first passed to GASNETI_CHECK_MAGIC together with
// GASNETI_MK_MAGIC, then converted with GASNETI_EXPORT_POINTER.
// NOTE(review): presumably the magic check catches corrupted or stale
// objects -- confirm against the macro's definition.
// This out-of-line definition is only compiled when no macro named
// gasneti_export_mk is already defined (see the enclosing #ifndef).
gex_MK_t gasneti_export_mk(gasneti_MK_t _real_mk) {
  GASNETI_CHECK_MAGIC(_real_mk, GASNETI_MK_MAGIC);
  return GASNETI_EXPORT_POINTER(gex_MK_t, _real_mk);
}
#endif

// Allocate and initialize a memory kind object for the given implementation.
//
// i_client: client recorded in the new object's _client field
// mk_impl:  implementation descriptor (must be non-NULL); supplies the class
//           and, optionally, the allocation size
// flags:    stored verbatim in the new object's _flags field
//
// The allocation is mk_impl->mk_sizeof bytes when that field is non-zero and
// sizeof(*mk) otherwise; it is asserted to be at least sizeof(*mk).
// Returns the new object with _cdata and _mk_conduit set to NULL and a
// reference count of zero.
gasneti_MK_t gasneti_alloc_mk(
            gasneti_Client_t                 i_client,
            gasneti_mk_impl_t               *mk_impl,
            gex_Flags_t                      flags)
{
  gasneti_MK_t mk;
  size_t alloc_size;

  gasneti_assert(mk_impl);

  // An implementation may request a larger object than the base type
  if (mk_impl->mk_sizeof) {
    alloc_size = mk_impl->mk_sizeof;
  } else {
    alloc_size = sizeof(*mk);
  }
  gasneti_assert(alloc_size >= sizeof(*mk));

  mk = gasneti_calloc(1, alloc_size);
  GASNETI_INIT_MAGIC(mk, GASNETI_MK_MAGIC);

  // Caller-provided state
  mk->_client = i_client;
  mk->_flags = flags;

  // Binding to the implementation
  mk->_mk_impl = mk_impl;
  mk->_mk_class = mk_impl->mk_class;

  // Fields that start out empty
  mk->_cdata = NULL;
  mk->_mk_conduit = NULL;
  gasneti_weakatomic32_set(&mk->_ref_count, 0, 0);

  return mk;
}

// Release the storage of a memory kind object.
// The object's magic is first re-initialized to GASNETI_MK_BAD_MAGIC and the
// object is then passed to gasneti_free().
// NOTE(review): presumably the bad magic lets later magic checks detect use
// of a stale handle -- confirm against the magic macros.
void gasneti_free_mk(gasneti_MK_t mk)
{
  GASNETI_INIT_MAGIC(mk, GASNETI_MK_BAD_MAGIC);
  gasneti_free(mk);
}

// Destroy a memory kind object whose reference count is (asserted to be) zero.
// When the implementation provides a destroy hook it is invoked with the
// object and `flags`; otherwise the object is released via gasneti_free_mk().
void gasneti_destroy_mk(gasneti_MK_t mk, gex_Flags_t flags)
{
  gasneti_assert(0 == gasneti_weakatomic32_read(&mk->_ref_count, 0));

  // No class-specific hook: fall back to the generic release
  if (! MK_IMPL(mk,destroy)) {
    gasneti_free_mk(mk);
    return;
  }

  MK_IMPL(mk,destroy)(mk, flags);
}

// Create a memory kind of the class named by args->gex_class.
//
// memkind_p: output location; written only when the call returns zero
// e_client:  client handle; imported and passed to the class-specific
//            creation routine
// args:      creation arguments; gex_class selects the class and gex_flags
//            must be zero
// flags:     must be zero
//
// A NULL client, memkind_p or args, any non-zero flags, or a request for
// GEX_MK_CLASS_HOST is reported via gasneti_fatalerror().  A class for which
// this build lacks support leaves through GASNETI_RETURN_ERRR(BAD_ARG,...).
// Otherwise the result is the return code of the class-specific creation
// routine or, when one is compiled in, of the conduit's create hook.
int gex_MK_Create(
            gex_MK_t                         *memkind_p,
            gex_Client_t                     e_client,
            const gex_MK_Create_args_t       *args,
            gex_Flags_t                      flags)
{
  gasneti_Client_t client = gasneti_import_client(e_client);

  // Tracing precedes validation, hence the guard against a NULL client
  GASNETI_TRACE_PRINTF(O,("gex_MK_Create: client='%s' flags=%d",
                          client ? client->_name : "(NULL)", flags));
  GASNETI_CHECK_INJECT();

  // Argument validation: every violation is a fatal usage error
  if (! client) {
    gasneti_fatalerror("Invalid call to gex_MK_Create with NULL client");
  }
  if (! memkind_p) {
    gasneti_fatalerror("Invalid call to gex_MK_Create with NULL memkind_p");
  }
  if (! args) {
    gasneti_fatalerror("Invalid call to gex_MK_Create with NULL args");
  }
  if (flags) {
    gasneti_fatalerror("Invalid call to gex_MK_Create with non-zero flags");
  }
  if (args->gex_flags) {
    gasneti_fatalerror("Invalid call to gex_MK_Create with non-zero args->gex_flags");
  }

  gasneti_MK_t result = NULL;
  int rc = GASNET_ERR_BAD_ARG;

  // Dispatch to the class-specific creation routine, where available
  switch (args->gex_class) {
    case GEX_MK_CLASS_HOST:
      gasneti_fatalerror("Invalid call to gex_MK_Create with GEX_MK_CLASS_HOST");
      break;

    case GEX_MK_CLASS_CUDA_UVA:
    #if GASNET_HAVE_MK_CLASS_CUDA_UVA
      rc = gasneti_MK_Create_cuda_uva(&result, client, args, flags);
    #else
      GASNETI_RETURN_ERRR(BAD_ARG,"This build lacks support for GEX_MK_CLASS_CUDA_UVA");
    #endif
      break;

    case GEX_MK_CLASS_HIP:
    #if GASNET_HAVE_MK_CLASS_HIP
      rc = gasneti_MK_Create_hip(&result, client, args, flags);
    #else
      GASNETI_RETURN_ERRR(BAD_ARG,"This build lacks support for GEX_MK_CLASS_HIP");
    #endif
      break;

    case GEX_MK_CLASS_ZE:
    #if GASNET_HAVE_MK_CLASS_ZE
      rc = gasneti_MK_Create_ze(&result, client, args, flags);
    #else
      GASNETI_RETURN_ERRR(BAD_ARG,"This build lacks support for GEX_MK_CLASS_ZE");
    #endif
      break;

    default: gasneti_unreachable_error(("Unknown MK class: %i",(int)args->gex_class));
  }

  // Class-specific creation failed: nothing to publish
  if (rc) {
    return rc;
  }

  // Sanity checks on per-class initialization
  gasneti_assert(result->_mk_class == args->gex_class);
  gasneti_assert(result->_mk_impl);
  gasneti_assert(MK_IMPL(result,class) == args->gex_class);
  gasneti_assert(MK_IMPL(result,name));
  gasneti_assert(strlen(MK_IMPL(result,name)));

#if GASNETC_MK_CREATE_HOOK
  // Conduit-specific hook, if any; on failure the new object is torn down
  rc = gasnetc_mk_create_hook(result, client, args, flags);
  if (rc) {
    gasneti_destroy_mk(result, 0); // TODO: any flags to pass through?
    return rc;
  }
#endif

  *memkind_p = gasneti_export_mk(result);
  return rc;
}

// Destroy a memory kind previously produced by gex_MK_Create().
//
// e_mk:  handle to destroy; GEX_MK_INVALID and GEX_MK_HOST are rejected
// flags: must be zero
//
// Every violation (reserved handle, non-zero flags, non-zero reference
// count) is reported via gasneti_fatalerror().  The reference count is the
// one raised by gasneti_MK_Segment_Create() and lowered by
// gasneti_MK_Segment_Destroy().
void gex_MK_Destroy(
            gex_MK_t                    e_mk,
            gex_Flags_t                 flags)
{
  // The reserved handles can never be destroyed
  if (e_mk == GEX_MK_INVALID) {
    gasneti_fatalerror("Invalid call to gex_MK_Destroy(GEX_MK_INVALID)");
  } else if (e_mk == GEX_MK_HOST) {
    gasneti_fatalerror("Invalid call to gex_MK_Destroy(GEX_MK_HOST)");
  }

  gasneti_MK_t i_mk = gasneti_import_mk(e_mk); // "this"
  gasneti_assert(i_mk->_mk_impl);
  gasneti_assert(MK_IMPL(i_mk,name));

  GASNETI_TRACE_PRINTF(O,("gex_MK_Destroy: memkind=%p, class='%s' flags=%d",
                          (void*)e_mk, MK_IMPL(i_mk,name), flags));
  GASNETI_CHECK_INJECT();

  if (flags != 0) {
    gasneti_fatalerror("Invalid call to gex_MK_Destroy with non-zero flags");
  }

  // A kind that is still referenced must not be destroyed
  const uint32_t live_refs = gasneti_weakatomic32_read(&i_mk->_ref_count, 0);
  if (live_refs != 0) {
    gasneti_fatalerror("Invalid call to gex_MK_Destroy with ref_count=%u",
                       (unsigned int)live_refs);
  }

#if GASNETC_MK_DESTROY_HOOK
  // Conduit-specific hook, if any, runs before the object is torn down
  gasnetc_mk_destroy_hook(i_mk);
#endif

  gasneti_destroy_mk(i_mk, flags);
}

// Create a segment backed by a non-host memory kind.
//
// i_segment_p: output location, forwarded to the class-specific hook
// i_client:    client making the request; must match the kind's client
// addr, size:  forwarded unchanged to the class-specific hook
// e_mk:        memory kind handle; must be neither GEX_MK_INVALID (asserted)
//              nor GEX_MK_HOST (rejected by the import)
// flags:       forwarded unchanged to the class-specific hook
//
// Returns GASNET_OK on success, in which case the kind's reference count has
// been incremented.  A non-zero result from the class hook is returned as-is,
// and a kind lacking a segment_create hook leaves through
// GASNETI_RETURN_ERRR(BAD_ARG,...); neither case touches the count.
int gasneti_MK_Segment_Create(
            gasneti_Segment_t *i_segment_p,
            gasneti_Client_t  i_client,
            void              *addr,
            uintptr_t         size,
            gex_MK_t          e_mk,
            gex_Flags_t       flags)
{
  gasneti_assert(e_mk != GEX_MK_INVALID); // Caller should have already checked user args

  gasneti_MK_t i_mk = gasneti_import_mk_nonhost(e_mk);

  if (i_client != i_mk->_client) {
    gasneti_fatalerror("Invalid call to gex_Segment_Create with a gex_MK_t from a different client");
  }

  // Without a class-specific hook this kind cannot back a segment
  if (! MK_IMPL(i_mk,segment_create)) {
    GASNETI_RETURN_ERRR(BAD_ARG,"gex_Segment_Create() called on unsupported memory kind");
  }

  int rc = MK_IMPL(i_mk,segment_create)(i_segment_p, i_mk, addr, size, flags);
  if (rc) {
    return rc;
  }

  // The new segment holds a reference on its memory kind
  gasneti_weakatomic32_increment(&i_mk->_ref_count, 0);
  return GASNET_OK;
}

// Tear down the memory-kind side of a segment.
// The segment's kind (i_segment->_kind) is imported as a non-host kind, the
// class-specific segment_destroy hook is run when one is provided, and the
// kind's reference count is then decremented -- the counterpart of the
// increment performed by gasneti_MK_Segment_Create().
void gasneti_MK_Segment_Destroy(
            gasneti_Segment_t i_segment)
{
  gasneti_assert(i_segment);

  // Read the kind before running the hook, which receives the segment itself
  gasneti_MK_t i_mk = gasneti_import_mk_nonhost(i_segment->_kind);

  // Class-specific hook, if any
  if (MK_IMPL(i_mk,segment_destroy)) {
    MK_IMPL(i_mk,segment_destroy)(i_segment);
  }

  // Drop the reference held on behalf of this segment
  gasneti_weakatomic32_decrement(&i_mk->_ref_count, 0);
}

// Produce a printable description of a memory kind handle.
// The reserved handles yield the literal strings "GEX_MK_INVALID" and
// "GEX_MK_HOST".  For any other handle the result comes from the class's
// format hook when one is provided, and is the class name otherwise.
const char *gasneti_formatmk(
            gex_MK_t e_mk)
{
  // Reserved handles have no backing object to consult
  if (e_mk == GEX_MK_INVALID) {
    return "GEX_MK_INVALID";
  }
  if (e_mk == GEX_MK_HOST) {
    return "GEX_MK_HOST";
  }

  gasneti_MK_t i_mk = gasneti_import_mk_nonhost(e_mk);

  // No class-specific formatter: the class name is the description
  if (! MK_IMPL(i_mk,format)) {
    return MK_IMPL(i_mk,name);
  }
  return MK_IMPL(i_mk,format)(i_mk);
}

// Run the class-specific segment_context_push hook for a segment.
// Returns the hook's result; returns 0 without calling anything when the
// segment's kind is GEX_MK_HOST or the class provides no such hook.
int gasneti_mk_segment_context_push(gasneti_Segment_t i_segment)
{
  gasneti_assert(i_segment);

  // Host segments have no class-specific hooks
  if (i_segment->_kind == GEX_MK_HOST) {
    return 0;
  }

  gasneti_MK_t i_mk = gasneti_import_mk_nonhost(i_segment->_kind);
  if (! MK_IMPL(i_mk,segment_context_push)) {
    return 0;
  }

  return MK_IMPL(i_mk,segment_context_push)(i_segment);
}

// Run the class-specific segment_context_pop hook for a segment.
// Returns the hook's result; returns 0 without calling anything when the
// segment's kind is GEX_MK_HOST or the class provides no such hook.
int gasneti_mk_segment_context_pop(gasneti_Segment_t i_segment)
{
  gasneti_assert(i_segment);

  // Host segments have no class-specific hooks
  if (i_segment->_kind == GEX_MK_HOST) {
    return 0;
  }

  gasneti_MK_t i_mk = gasneti_import_mk_nonhost(i_segment->_kind);
  if (! MK_IMPL(i_mk,segment_context_pop)) {
    return 0;
  }

  return MK_IMPL(i_mk,segment_context_pop)(i_segment);
}
