#![cfg_attr(target_arch = "spirv", no_std)]
#![feature(lang_items)]
#![feature(register_attr)]
#![register_attr(spirv)]

use spirv_std::glam::{const_vec3, Vec2, Vec2Swizzles, Vec3, Vec4, Vec4Swizzles};
use spirv_std::storage_class::{Input, Output, UniformConstant};

#[cfg(target_arch = "spirv")]
use spirv_std::num_traits::Float;

const RAYMARCH_MAX_STEPS: i32 = 100;
const RAYMARCH_MAX_DIST: f32 = 100.0;
const RAYMARCH_SURF_DIST: f32 = 0.01;

/// Estimates the unit surface normal at `point` as the gradient of the scene's
/// signed distance field, using a backward finite difference with step 0.01.
fn get_normal(point: Vec3) -> Vec3 {
	let center = get_distance(point);

	// Sample the field a small step "behind" `point` along each axis.
	let h = Vec2::new(0.01, 0.0);
	let behind_x = get_distance(point - h.xyy());
	let behind_y = get_distance(point - h.yxy());
	let behind_z = get_distance(point - h.yyx());

	// Difference between the centre sample and each axis sample ≈ gradient.
	let gradient = Vec3::splat(center) - Vec3::new(behind_x, behind_y, behind_z);
	gradient.normalize()
}

/// Diffuse lighting at surface `point` from a single animated point light,
/// with hard shadows.
///
/// The light sits at height y = 5 and moves with `time` on a circle of
/// radius 2 centred at (0, 5, 6) in the xz-plane.
///
/// Returns the Lambertian factor `N·L` clamped to `[0, 1]`. The factor is
/// scaled by 0.1 when a shadow ray cast toward the light reaches geometry
/// before travelling the full distance to the light.
fn get_light(point: Vec3, time: f32) -> f32 {
	// Light orbits around (0, 5, 6) with radius 2.
	let light_pos = Vec3::new(time.sin() * 2.0, 5.0, 6.0 + time.cos() * 2.0);

	// Compute the vector to the light once; reuse it for direction and distance.
	let to_light = light_pos - point;
	let light_dist = to_light.length();
	let light_dir = to_light.normalize();
	let normal = get_normal(point);

	// Lambertian term; negative values (surface facing away) clamp to 0.
	let mut diffuse = normal.dot(light_dir).max(0.0).min(1.0);

	// Shadow ray: start 2 * RAYMARCH_SURF_DIST off the surface along the
	// normal, so the march begins outside the hit threshold of the very
	// surface being shaded.
	let occluder_dist = ray_march(point + normal * RAYMARCH_SURF_DIST * 2.0, light_dir);

	// Something was hit before reaching the light: darken (in shadow).
	if occluder_dist < light_dist {
		diffuse *= 0.1;
	}

	diffuse
}


/// Signed distance from `point` to the closest object in the scene.
///
/// The scene contains a ground plane at y = 0 and a sphere of radius 1
/// centred at (0, 1, 6).
fn get_distance(point: Vec3) -> f32 {
	// Packed sphere description: xyz = centre, w = radius.
	let sphere = Vec4::new(0.0, 1.0, 6.0, 1.0);
	let center = sphere.xyz();
	let radius = sphere.w;

	// Distance to the sphere surface (negative inside).
	let to_sphere = (point - center).length() - radius;
	// Distance to the horizontal plane y = 0 is just the height.
	let to_plane = point.y;

	// The nearest object wins.
	to_plane.min(to_sphere)
}

/// Marches a ray from `ray_orig` along `ray_dir` through the scene's signed
/// distance field (sphere tracing) and returns the distance travelled.
///
/// Each iteration advances by the distance to the nearest surface reported by
/// `get_distance`. The march stops when:
/// - the current step is smaller than `RAYMARCH_SURF_DIST` (a surface hit);
/// - the travelled distance exceeds `RAYMARCH_MAX_DIST` (a miss); or
/// - `RAYMARCH_MAX_STEPS` iterations have run.
///
/// Assumes `ray_dir` is unit length, so step sizes are true distances.
fn ray_march(ray_orig: Vec3, ray_dir: Vec3) -> f32 {
	let mut dist: f32 = 0.0;

	// A `for _ in a..b` loop can be used in later versions.
	// See: https://github.com/EmbarkStudios/rust-gpu/pull/493
	let mut i = 0;
	while i < RAYMARCH_MAX_STEPS {
		let point = ray_orig + ray_dir * dist;
		let scene_dist = get_distance(point);
		dist += scene_dist;

		// `||` is not supported by this rust-gpu version, so the two exit
		// conditions are checked separately.
		if dist > RAYMARCH_MAX_DIST {
			break;
		}

		// Hit: the distance to the nearest surface (this step), not the
		// accumulated distance, must fall below the threshold. Testing `dist`
		// here would only detect surfaces right next to the ray origin.
		if scene_dist < RAYMARCH_SURF_DIST {
			break;
		}

		i += 1;
	}

	dist
}

// NOTE(review): presumably allowed because `#[spirv(...)]` attributes go
// unused when this crate is built for a non-SPIR-V target — confirm.
#[allow(unused_attributes)]
#[spirv(fragment)]
/// Fragment entry point.
///
/// Raymarches the scene from a fixed camera and writes a greyscale,
/// diffuse-lit colour with alpha 1 to `output`.
pub fn main_fs(
	#[spirv(frag_coord)] frag_coord: Input<Vec4>,
	u_time: UniformConstant<f32>,
	u_resolution: UniformConstant<Vec2>,
	mut output: Output<Vec4>,
) {
	let coord = frag_coord.load();
	let res = u_resolution.load();
	let time = u_time.load();

	// Move the origin to the screen centre, then divide by the height:
	// uv.y spans [-0.5, 0.5], and uv.x is scaled by the same factor so the
	// image keeps its aspect ratio.
	let uv = (coord.xy() - res * 0.5) / res.y;

	// Camera at (0, 1, 0) looking down +z through the uv plane at z = 1.
	let ray_orig = Vec3::new(0.0, 1.0, 0.0);
	let ray_dir = Vec3::new(uv.x, uv.y, 1.0).normalize();

	let dist = ray_march(ray_orig, ray_dir);
	let point = ray_orig + ray_dir * dist;

	let diff_light = get_light(point, time);
	let color = Vec3::splat(diff_light);

	output.store(color.extend(1.0))
}