boost/compute/lambda/context.hpp
//---------------------------------------------------------------------------//
// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://boostorg.github.com/compute for more information.
//---------------------------------------------------------------------------//
#ifndef BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
#define BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
#include <boost/proto/core.hpp>
#include <boost/proto/context.hpp>
#include <boost/type_traits.hpp>
#include <boost/preprocessor/repetition.hpp>
#include <boost/compute/config.hpp>
#include <boost/compute/function.hpp>
#include <boost/compute/lambda/result_of.hpp>
#include <boost/compute/lambda/functional.hpp>
#include <boost/compute/type_traits/result_of.hpp>
#include <boost/compute/type_traits/type_name.hpp>
#include <boost/compute/detail/meta_kernel.hpp>
namespace boost {
namespace compute {
namespace lambda {
namespace mpl = boost::mpl;
namespace proto = boost::proto;
// Defines a context call operator for the binary proto tag `tag` that writes
// "<lhs> op <rhs>" to `stream`. Either operand that is itself a compound
// expression (proto arity > 0) is wrapped in parentheses; terminals are
// written bare.
#define BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(tag, op) \
    template<class LHS, class RHS> \
    void operator()(tag, const LHS &lhs, const RHS &rhs) \
    { \
        /* parenthesize compound operands so the generated code keeps */ \
        /* the grouping of the proto expression tree                  */ \
        const bool wrap_lhs = proto::arity_of<LHS>::value > 0; \
        const bool wrap_rhs = proto::arity_of<RHS>::value > 0; \
        \
        if(wrap_lhs){ stream << '('; } \
        proto::eval(lhs, *this); \
        if(wrap_lhs){ stream << ')'; } \
        \
        stream << op; \
        \
        if(wrap_rhs){ stream << '('; } \
        proto::eval(rhs, *this); \
        if(wrap_rhs){ stream << ')'; } \
    }
// lambda expression context
//
// A proto callable context that walks a lambda expression tree and writes the
// corresponding OpenCL C expression into a meta_kernel. Placeholder<I>
// terminals are replaced by the I-th element of the Args tuple; every other
// terminal is written as a literal via meta_kernel::lit().
template<class Args>
struct context : proto::callable_context<context<Args> >
{
    typedef void result_type;
    typedef Args args_tuple;

    // create a lambda context for kernel with args
    //
    // kernel: meta_kernel the generated code is streamed into (stored by
    //         reference, so it must outlive this context)
    // args_:  tuple whose I-th element is streamed in place of placeholder<I>
    context(boost::compute::detail::meta_kernel &kernel, const Args &args_)
        : stream(kernel),
          args(args_)
    {
    }

    // handle terminals
    template<class T>
    void operator()(proto::tag::terminal, const T &x)
    {
        // terminal values in lambda expressions are always literals
        stream << stream.lit(x);
    }

    // handle placeholders
    template<int I>
    void operator()(proto::tag::terminal, placeholder<I>)
    {
        stream << boost::get<I>(args);
    }

    // handle functions: forwards the context and the argument expressions to
    // the apply() member of the function object held in the function terminal
#define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG(z, n, unused) \
    BOOST_PP_COMMA_IF(n) BOOST_PP_CAT(const Arg, n) BOOST_PP_CAT(&arg, n)

#define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION(z, n, unused) \
    template<class F, BOOST_PP_ENUM_PARAMS(n, class Arg)> \
    void operator()( \
        proto::tag::function, \
        const F &function, \
        BOOST_PP_REPEAT(n, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG, ~) \
    ) \
    { \
        proto::value(function).apply(*this, BOOST_PP_ENUM_PARAMS(n, arg)); \
    }

    BOOST_PP_REPEAT_FROM_TO(1, BOOST_COMPUTE_MAX_ARITY, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION, ~)

    // both helper macros are local to this class; don't leak either of them
#undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION
#undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG

    // operators
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::plus, '+')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::minus, '-')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::multiplies, '*')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::divides, '/')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::modulus, '%')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less, '<')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater, '>')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less_equal, "<=")
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater_equal, ">=")
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::equal_to, "==")
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::not_equal_to, "!=")
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_and, "&&")
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_or, "||")
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_and, '&')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_or, '|')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_xor, '^')
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::assign, '=')

    // subscript operator
    template<class LHS, class RHS>
    void operator()(proto::tag::subscript, const LHS &lhs, const RHS &rhs)
    {
        // a compound indexed operand must be parenthesized: (_1 + 1)[_2]
        // would otherwise be emitted as x+1[y], which C parses as x + y[1]
        eval_operand(lhs);
        stream << '[';
        // the index is already delimited by the brackets
        proto::eval(rhs, *this);
        stream << ']';
    }

    // ternary conditional operator
    template<class Pred, class Arg1, class Arg2>
    void operator()(proto::tag::if_else_, const Pred &p, const Arg1 &x, const Arg2 &y)
    {
        // compound operands are parenthesized; e.g. an assignment in the
        // false branch (x<2?x=1:x=2) is ill-formed C without the parentheses
        eval_operand(p);
        stream << '?';
        eval_operand(x);
        stream << ':';
        eval_operand(y);
    }

    // accessed by function objects' apply() (see the function handler above)
    boost::compute::detail::meta_kernel &stream;
    Args args;

private:
    // evaluates a sub-expression, surrounding it with parentheses unless it
    // is a terminal (proto arity 0), so that operator precedence in the
    // generated OpenCL C code matches the shape of the proto expression tree
    template<class Operand>
    void eval_operand(const Operand &operand)
    {
        const bool compound = proto::arity_of<Operand>::value > 0;

        if(compound){
            stream << '(';
        }
        proto::eval(operand, *this);
        if(compound){
            stream << ')';
        }
    }
};
namespace detail {

// A lambda expression bound to a single argument. Streaming it into a
// meta_kernel writes the expression with placeholder _1 replaced by the
// stored argument.
template<class Expr, class Arg>
struct invoked_unary_expression
{
    typedef typename ::boost::compute::result_of<Expr(Arg)>::type result_type;

    invoked_unary_expression(const Expr &expr, const Arg &arg)
        : m_expr(expr), m_arg(arg)
    {
    }

    Expr m_expr;
    Arg m_arg;
};

// writes the bound unary expression to kernel and returns kernel for chaining
template<class Expr, class Arg>
boost::compute::detail::meta_kernel&
operator<<(boost::compute::detail::meta_kernel &kernel,
           const invoked_unary_expression<Expr, Arg> &expr)
{
    typedef boost::tuple<Arg> bound_args_type;
    typedef context<bound_args_type> eval_context_type;

    eval_context_type eval_ctx(kernel, boost::make_tuple(expr.m_arg));
    proto::eval(expr.m_expr, eval_ctx);

    return kernel;
}

// A lambda expression bound to two arguments. Streaming it into a meta_kernel
// writes the expression with placeholders _1 and _2 replaced by the stored
// arguments.
template<class Expr, class Arg1, class Arg2>
struct invoked_binary_expression
{
    typedef typename ::boost::compute::result_of<Expr(Arg1, Arg2)>::type result_type;

    invoked_binary_expression(const Expr &expr,
                              const Arg1 &arg1,
                              const Arg2 &arg2)
        : m_expr(expr), m_arg1(arg1), m_arg2(arg2)
    {
    }

    Expr m_expr;
    Arg1 m_arg1;
    Arg2 m_arg2;
};

// writes the bound binary expression to kernel and returns kernel for chaining
template<class Expr, class Arg1, class Arg2>
boost::compute::detail::meta_kernel&
operator<<(boost::compute::detail::meta_kernel &kernel,
           const invoked_binary_expression<Expr, Arg1, Arg2> &expr)
{
    typedef boost::tuple<Arg1, Arg2> bound_args_type;
    typedef context<bound_args_type> eval_context_type;

    eval_context_type eval_ctx(kernel,
                               boost::make_tuple(expr.m_arg1, expr.m_arg2));
    proto::eval(expr.m_expr, eval_ctx);

    return kernel;
}

} // end detail namespace
// forward declare domain: expression<> (below) names it as its proto domain
// before the full definition, which itself refers to expression<>
struct domain;
// lambda expression wrapper
//
// Wraps a proto expression in the lambda domain. Calling it with one or two
// arguments yields an object that, when streamed into a meta_kernel, writes
// the expression with the placeholders replaced by those arguments. It can
// also be converted to a boost::compute::function<> of arity one or two.
template<class Expr>
struct expression : proto::extends<Expr, expression<Expr>, domain>
{
    typedef proto::extends<Expr, expression<Expr>, domain> base_type;

    BOOST_PROTO_EXTENDS_USING_ASSIGN(expression)

    expression(const Expr &expr = Expr())
        : base_type(expr)
    {
    }

    // result_of protocol
    template<class Signature>
    struct result
    {
    };

    template<class This>
    struct result<This()>
    {
        typedef
            typename ::boost::compute::lambda::result_of<Expr>::type type;
    };

    template<class This, class Arg>
    struct result<This(Arg)>
    {
        typedef
            typename ::boost::compute::lambda::result_of<
                Expr,
                boost::tuple<Arg>
            >::type type;
    };

    template<class This, class Arg1, class Arg2>
    struct result<This(Arg1, Arg2)>
    {
        typedef typename
            ::boost::compute::lambda::result_of<
                Expr,
                boost::tuple<Arg1, Arg2>
            >::type type;
    };

    // binds this expression to one argument (substituted for _1)
    template<class Arg>
    detail::invoked_unary_expression<expression<Expr>, Arg>
    operator()(const Arg &x) const
    {
        return detail::invoked_unary_expression<expression<Expr>, Arg>(*this, x);
    }

    // binds this expression to two arguments (substituted for _1 and _2)
    template<class Arg1, class Arg2>
    detail::invoked_binary_expression<expression<Expr>, Arg1, Arg2>
    operator()(const Arg1 &x, const Arg2 &y) const
    {
        return detail::invoked_binary_expression<
            expression<Expr>,
            Arg1,
            Arg2
        >(*this, x, y);
    }

    // function<> conversion operator
    //
    // Generates an inline OpenCL C function named "lambda" whose body returns
    // this expression evaluated on its parameter, and builds a function<>
    // from that source with make_function_from_source.
    template<class R, class A1>
    operator function<R(A1)>() const
    {
        using ::boost::compute::detail::meta_kernel;

        std::stringstream source;

        // variable name must match the parameter name generated from 'x'
        ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");

        source << "inline " << type_name<R>() << " lambda"
               << ::boost::compute::detail::generate_argument_list<R(A1)>('x')
               << "{\n"
               << "    return " << meta_kernel::expr_to_string((*this)(arg1)) << ";\n"
               << "}\n";

        return make_function_from_source<R(A1)>("lambda", source.str());
    }

    // two-argument variant of the conversion above
    template<class R, class A1, class A2>
    operator function<R(A1, A2)>() const
    {
        using ::boost::compute::detail::meta_kernel;

        std::stringstream source;

        // each placeholder variable must carry the type of the parameter it
        // stands for: x is an A1 and y is an A2 (declaring y as A1 mis-types
        // it, and the deduced result type, whenever A1 and A2 differ).
        // NOTE(review): names assume generate_argument_list numbers the
        // parameters x, y, ... starting from 'x' — confirm against its source.
        ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
        ::boost::compute::detail::meta_kernel_variable<A2> arg2("y");

        source << "inline " << type_name<R>() << " lambda"
               << ::boost::compute::detail::generate_argument_list<R(A1, A2)>('x')
               << "{\n"
               << "    return " << meta_kernel::expr_to_string((*this)(arg1, arg2)) << ";\n"
               << "}\n";

        return make_function_from_source<R(A1, A2)>("lambda", source.str());
    }
};
// lambda expression domain: proto passes every expression created in this
// domain through proto::generator<expression>, so operators applied to lambda
// expressions produce lambda::expression<> wrappers
struct domain
    : public proto::domain<proto::generator<expression> >
{
};
} // end lambda namespace
} // end compute namespace
} // end boost namespace
#endif // BOOST_COMPUTE_LAMBDA_CONTEXT_HPP